1################################################################## 2## (c) Copyright 2015- by Jaron T. Krogel ## 3################################################################## 4 5 6#====================================================================# 7# structure.py # 8# Support for atomic structure I/O, generation, and manipulation. # 9# # 10# Content summary: # 11# Structure # 12# Represents a simulation cell containing a set of atoms. # 13# Many functions for manipulating structures or obtaining # 14# data regarding local atomic structure. # 15# # 16# generate_cell # 17# User-facing function to generate an empty simulation cell. # 18# # 19# generate_structure # 20# User-facing function to specify arbitrary atomic structures # 21# or generate structures corresponding to atoms, dimers, or # 22# crystals. # 23# # 24#====================================================================# 25 26""" 27The :py:mod:`structure` module provides support for atomic structure I/O, 28generation, and manipulation. 29 30 31List of module contents 32----------------------- 33 34Read cif file functions: 35 36* :py:func:`read_cif_celldata` 37* :py:func:`read_cif_cell` 38* :py:func:`read_cif` 39 40Operations on logical conditions: 41 42* :py:func:`equate` 43* :py:func:`negate` 44 45Create a Monkhorst-Pack k-point mesh function 46 47* :py:func:`kmesh` 48 49Tile matrix malipulation functions 50 51* :py:func:`reduce_tilematrix` 52* :py:func:`tile_magnetization` 53 54Rotate plane function 55 56* :py:func:`rotate_plane` 57 58Trivial filter function 59 60* :py:func:`trivial_filter` 61 62* :py:class:`MaskFilter` 63 64* :py:func:`optimal_tilematrix` 65 66Base class for :py:class:`Structure` class: 67 68* :py:class:`Sobj` 69 70Base class for :py:class:`DefectStructure`, :py:class:`Crystal`, and :py:class:`Jellium` classes: 71 72* :py:class:`Structure` 73 74SeeK-path functions 75 76* :py:func:`\_getseekpath` 77* :py:func:`get_conventional_cell` 78* :py:func:`get_primitive_cell` 79* :py:func:`get_kpath` 80* :py:func:`get_symmetry` 81* :py:func:`get_structure_with_bands` 82* :py:func:`get_band_tiling` 83* :py:func:`get_seekpath_full` 84 85Interpolate structures functions 86 87* :py:func:`interpolate_structures` 88 89Animate structures functions 90 91* :py:func:`structure_animation` 92 93Concrete :py:class:`Structure` classes: 94 95* :py:class:`DefectStructure` 96* :py:class:`Crystal` 97* :py:class:`Jellium` 98 99Structure generation functions: 100 101* :py:func:`generate_cell` 102* :py:func:`generate_structure` 103* :py:func:`generate_atom_structure` 104* :py:func:`generate_dimer_structure` 105* :py:func:`generate_trimer_structure` 106* :py:func:`generate_jellium_structure` 107* :py:func:`generate_crystal_structure` 108* :py:func:`generate_defect_structure` 109 110Read structure functions 111 112* :py:func:`read_structure` 113 114 115 116Module contents 117--------------- 118""" 119 120import os 121import numpy as np 122from copy import deepcopy 123from random import randint 124from numpy import abs,all,append,arange,around,array,atleast_2d,ceil,cos,cross,cross,diag,dot,empty,exp,flipud,floor,identity,isclose,logical_not,mgrid,mod,ndarray,ones,pi,round,sign,sin,sqrt,uint64,zeros 125from numpy.linalg import inv,det,norm 126from unit_converter import convert 127from numerics import nearest_neighbors,convex_hull,voronoi_neighbors 128from periodic_table import pt,is_element 129from fileio import XsfFile,PoscarFile 130from generic import obj 131from developer import DevBase,unavailable,error,warn 132from debug import ci,ls,gs 133 134try: 135 from scipy.special import erfc 136except: 137 erfc = unavailable('scipy.special','erfc') 138#end try 139try: 140 import matplotlib.pyplot as plt 141 from matplotlib.pyplot import plot,subplot,title,xlabel,ylabel 142except: 143 plot,subplot,title,xlabel,ylabel,plt = unavailable('matplotlib.pyplot','plot','subplot','title','xlabel','ylabel','plt') 144#end try 145 146 147 148 149 150# installation instructions to enable cif file read 151# 152# cif file support in Nexus currently requires two external libraries 153# PyCifRW - base interface to read cif files into object format: CifFile 154# cif2cell - translation layer from CifFile object to cell reconstruction: CellData 155# (note: cif2cell installation includes PyCifRW) 156# 157# installation of cif2cell 158# go to http://sourceforge.net/projects/cif2cell/ 159# click on Download (example: cif2cell-1.2.10.tar.gz) 160# unpack directory (tar -xzf cif2cell-1.2.10.tar.gz) 161# enter directory (cd cif2cell-1.2.10) 162# install cif2cell (python setup.py install) 163# check python installation 164# >python 165# >>>from CifFile import CifFile 166# >>>from uctools import CellData 167# 168# Nexus is currently compatible with 169# cif2cell-1.2.10 and PyCifRW-3.3 170# cif2cell-1.2.7 and PyCifRW-4.1.1 171# compatibility last tested: 20 Mar 2017 172# 173try: 174 from CifFile import CifFile 175except: 176 CifFile = unavailable('CifFile','CifFile') 177#end try 178try: 179 from cif2cell.uctools import CellData 180except: 181 CellData = unavailable('cif2cell.uctools','CellData') 182#end try 183 184 185cif2cell_unit_dict = dict(angstrom='A',bohr='B',nm='nm') 186 187 188 189def read_cif_celldata(filepath,block=None,grammar='1.1'): 190 # read cif file with PyCifRW 191 path,cif_file = os.path.split(filepath) 192 if path!='': 193 cwd = os.getcwd() 194 os.chdir(path) 195 #end if 196 cf = CifFile(cif_file,grammar=grammar) 197 #cf = ReadCif(cif_file,grammar=grammar) 198 if path!='': 199 os.chdir(cwd) 200 #end if 201 if block is None: 202 block = list(cf.keys())[0] 203 #end if 204 cb = cf.get(block) 205 if cb is None: 206 error('block {0} was not found in cif file {1}'.format(block,filepath),'read_cif_celldata') 207 #end if 208 209 # repack H-M symbols as normal strings so CellData.getFromCIF won't choke on unicode 210 #for k in ['_symmetry_space_group_name_H-M','_space_group_name_H-M_alt','_symmetry_space_group_name_h-m','_space_group_name_h-m_alt']: 211 # if k in cb.block: 212 # v = cb.block[k] 213 # if isinstance(v,(list,tuple)): 214 # for i in range(len(v)): 215 # if isinstance(v[i],unicode): 216 # v[i] = str(v[i]) 217 # #end if 218 # #end for 219 # #end if 220 # #end if 221 ##end for 222 223 # extract structure from CifFile with uctools CellData class 224 cd = CellData() 225 cd.getFromCIF(cb) 226 227 return cd 228#end def read_cif_celldata 229 230 231 232def read_cif_cell(filepath,block=None,grammar='1.1',cell='prim'): 233 cd = read_cif_celldata(filepath,block,grammar) 234 235 if cell.startswith('prim'): 236 cell = cd.primitive() 237 elif cell.startswith('conv'): 238 cell = cd.conventional() 239 else: 240 error('cell argument must be primitive or conventional\nyou provided: {0}'.format(cell),'read_cif_cell') 241 #end if 242 243 return cell 244#end def read_cif_cell 245 246 247 248def read_cif(filepath,block=None,grammar='1.1',cell='prim',args_only=False): 249 if isinstance(filepath,str): 250 cell = read_cif_cell(filepath,block,grammar,cell) 251 else: 252 cell = filepath 253 #end if 254 255 # create Structure object from cell 256 if cell.alloy: 257 error('cannot handle alloys','read_cif') 258 #end if 259 units = cif2cell_unit_dict[cell.unit] 260 scale = float(cell.lengthscale) 261 scale = convert(scale,units,'A') 262 units = 'A' 263 axes = scale*array(cell.latticevectors,dtype=float) 264 elem = [] 265 pos = [] 266 for wyckoff_atoms in cell.atomdata: 267 for atom in wyckoff_atoms: 268 elem.append(str(list(atom.species.keys())[0])) 269 pos.append(atom.position) 270 #end for 271 #end for 272 pos = dot(array(pos,dtype=float),axes) 273 274 if not args_only: 275 s = Structure( 276 axes = axes, 277 elem = elem, 278 pos = pos, 279 units = units 280 ) 281 return s 282 else: 283 return axes,elem,pos,units 284 #end if 285#end def read_cif 286 287 288 289 290 291# installation instructions for spglib interface 292# 293# this is bootstrapped off of spglib's ASE Python interface 294# 295# installation of spglib 296# go to http://sourceforge.net/projects/spglib/files/ 297# click on Download spglib-1.8.2.tar.gz (952.6 kB) 298# unpack directory (tar -xzf spglib-1.8.2.tar.gz) 299# enter ase directory (cd spglib-1.8.2/python/ase/) 300# build and install (sudo python setup.py install) 301from periodic_table import pt as ptable 302#try: 303# from pyspglib import spglib 304#except: 305# spglib = unavailable('pyspglib','spglib') 306##end try 307try: 308 import spglib 309except: 310 spglib = unavailable('spglib') 311#end try 312 313 314 315 316 317def equate(expr): 318 return expr 319#end def equate 320 321def negate(expr): 322 return not expr 323#end def negate 324 325 326 327def kmesh(kaxes,dim,shift=None): 328 ''' 329 Create a Monkhorst-Pack k-point mesh 330 ''' 331 if shift is None: 332 shift = (0.,0,0) 333 #end if 334 ndim = len(dim) 335 d = array(dim) 336 s = array(shift) 337 s.shape = 1,ndim 338 d.shape = 1,ndim 339 kp = empty((1,ndim),dtype=float) 340 kgrid = empty((d.prod(),ndim)) 341 n=0 342 for k in range(dim[2]): 343 for j in range(dim[1]): 344 for i in range(dim[0]): 345 kp[:] = i,j,k 346 kp = dot((kp+s)/d,kaxes) 347 #kp = (kp+s)/d 348 kgrid[n] = kp 349 n+=1 350 #end for 351 #end for 352 #end for 353 return kgrid 354#end def kmesh 355 356 357 358def reduce_tilematrix(tiling): 359 tiling = array(tiling) 360 t = array(tiling,dtype=int) 361 if abs(tiling-t).sum()>1e-6: 362 Structure.class_error('requested tiling is non-integer\n tiling requested: '+str(tiling)) 363 #end if 364 365 dim = len(t) 366 matrix_tiling = t.shape == (dim,dim) 367 if matrix_tiling: 368 if abs(det(t))==0: 369 Structure.class_error('requested tiling matrix is singular\ntiling requested: {0}'.format(t)) 370 #end if 371 #find a tiling tuple from the tiling matrix 372 # do this by shearing the tiling matrix (or equivalently the tiled cell) 373 # until it is orthogonal (in the untiled cell axes) 374 # this is just rearranging triangular tiles of space to reshape the cell 375 # so that t1*t2*t3 = det(T) = det(A_tiled)/det(A_untiled) 376 #this way the atoms in the (perhaps oddly shaped) supercell can be 377 # obtained from simple translations of the untiled cell positions 378 T = t #tiling matrix 379 tilematrix = T.copy() 380 del t 381 tbar = identity(dim) #basis for shearing 382 dr = list(range(dim)) 383 #dr = [1,0,2] 384 other = dim*[0] # other[d] = dimensions other than d 385 for d in dr: 386 other[d] = set(dr)-set([d]) 387 #end for 388 #move each axis to be parallel to barred directions 389 # these are volume preserving shears of the supercell 390 # each shear keeps two cell face planes fixed while moving the others 391 tvecs = [] 392 for dp in [(0,1,2),(2,0,1),(1,2,0),(2,1,0),(0,2,1),(1,0,2)]: 393 success = True 394 Tnew = array(T,dtype=float) #sheared/orthogonal tiling matrix 395 for d in dr: 396 tb = tbar[dp[d]] 397 t = T[d] 398 d2,d3 = other[d] 399 n = cross(Tnew[d2],Tnew[d3]) #vector normal to 2 cell faces 400 vol = dot(n,t) 401 bcomp = dot(n,tb) 402 if abs(bcomp)<1e-6: 403 success = False 404 break 405 #end if 406 tn = vol*1./bcomp*tb #new axis vector 407 Tnew[d] = tn 408 #end for 409 if success: 410 # apply inverse permutation, if needed 411 Tn = Tnew.copy() 412 for d in dr: 413 d2 = dp[d] 414 Tnew[d2] = Tn[d] 415 #end for 416 #the resulting tiling matrix should be diagonal and integer 417 tr = diag(Tnew) 418 nondiagonal = abs(Tnew-diag(tr)).sum()>1e-6 419 if nondiagonal: 420 Structure.class_error('could not find a diagonal tiling matrix for generating tiled coordinates') 421 #end if 422 tvecs.append(abs(tr)) 423 #end if 424 #end for 425 tvecs_old = tvecs 426 tvecs = [] 427 tvset = set() 428 for tv in tvecs_old: 429 tvk = tuple(array(around(1e7*tv),dtype=uint64)) 430 if tvk not in tvset: 431 tvset.add(tvk) 432 tvecs.append(tv) 433 #end if 434 #end for 435 tilevector = array(tvecs) 436 else: 437 tilevector = t 438 tilematrix = diag(t) 439 #end if 440 441 return tilematrix,tilevector 442#end def reduce_tilematrix 443 444 445 446def rotate_plane(plane,angle,points,units='degrees'): 447 if units=='degrees': 448 angle *= pi/180 449 elif not units.startswith('rad'): 450 error('angular units must be degrees or radians\nyou provided: {0}'.format(angle),'rotate_plane') 451 #end if 452 c = cos(angle) 453 s = sin(angle) 454 if plane=='xy': 455 R = [[ c,-s, 0], 456 [ s, c, 0], 457 [ 0, 0, 1]] 458 elif plane=='yx': 459 R = [[ c, s, 0], 460 [-s, c, 0], 461 [ 0, 0, 1]] 462 elif plane=='yz': 463 R = [[ 1, 0, 0], 464 [ 0, c,-s], 465 [ 0, s, c]] 466 elif plane=='zy': 467 R = [[ 1, 0, 0], 468 [ 0, c, s], 469 [ 0,-s, c]] 470 elif plane=='zx': 471 R = [[ c, 0, s], 472 [ 0, 1, 0], 473 [-s, 0, c]] 474 elif plane=='xz': 475 R = [[ c, 0,-s], 476 [ 0, 1, 0], 477 [ s, 0, c]] 478 else: 479 error('plane must be xy/yx/yz/zy/zx/xz\nyou provided: {0}'.format(plane),'rotate_plane') 480 #end if 481 R = array(R,dtype=float) 482 return dot(R,points.T).T 483#end def rotate_plane 484 485 486 487opt_tm_matrices = obj() 488opt_tm_wig_indices = obj() 489 490def trivial_filter(T): 491 return True 492#end def trival_filter 493 494class MaskFilter(DevBase): 495 def set(self,mask,dim=3): 496 omask = array(mask) 497 mask = array(mask,dtype=bool) 498 if mask.size==dim: 499 mvec = mask.ravel() 500 mask = empty((dim,dim),dtype=bool) 501 i=0 502 for mi in mvec: 503 j=0 504 for mj in mvec: 505 mask[i,j] = mi==mj 506 j+=1 507 #end for 508 i+=1 509 #end for 510 elif mask.shape!=(dim,dim): 511 error('shape of mask array must be {0},{0}\nshape received: {1},{2}\nmask array received: {3}'.format(dim,mask.shape[0],mask.shape[1],omask),'optimal_tilematrix') 512 #end if 513 self.mask = mask==False 514 #end def set 515 516 def __call__(self,T): 517 return (T[self.mask]==0).all() 518 #end def __call__ 519#end class MaskFilter 520mask_filter = MaskFilter() 521 522 523def optimal_tilematrix(axes,volfac,dn=1,tol=1e-3,filter=trivial_filter,mask=None,nc=5): 524 if mask is not None: 525 mask_filter.set(mask) 526 filter = mask_filter 527 #end if 528 dim = 3 529 if isinstance(axes,Structure): 530 axes = axes.axes 531 else: 532 axes = array(axes,dtype=float) 533 #end if 534 if not isinstance(volfac,int): 535 volfac = int(around(volfac)) 536 #end if 537 volume = abs(det(axes))*volfac 538 axinv = inv(axes) 539 cube = volume**(1./3)*identity(dim) 540 Tref = array(around(dot(cube,axinv)),dtype=int) 541 # calculate and store all tiling matrix variations 542 if dn not in opt_tm_matrices: 543 mats = [] 544 rng = tuple(range(-dn,dn+1)) 545 for n1 in rng: 546 for n2 in rng: 547 for n3 in rng: 548 for n4 in rng: 549 for n5 in rng: 550 for n6 in rng: 551 for n7 in rng: 552 for n8 in rng: 553 for n9 in rng: 554 mats.append((n1,n2,n3,n4,n5,n6,n7,n8,n9)) 555 #end for 556 #end for 557 #end for 558 #end for 559 #end for 560 #end for 561 #end for 562 #end for 563 #end for 564 mats = array(mats,dtype=int) 565 mats.shape = (2*dn+1)**(dim*dim),dim,dim 566 opt_tm_matrices[dn] = mats 567 else: 568 mats = opt_tm_matrices[dn] 569 #end if 570 # calculate and store all wigner image indices 571 if nc not in opt_tm_wig_indices: 572 inds = [] 573 rng = tuple(range(-nc,nc+1)) 574 for k in rng: 575 for j in rng: 576 for i in rng: 577 if i!=0 or j!=0 or k!=0: 578 inds.append((i,j,k)) 579 #end if 580 #end for 581 #end for 582 #end for 583 inds = array(inds,dtype=int) 584 opt_tm_wig_indices[nc] = inds 585 else: 586 inds = opt_tm_wig_indices[nc] 587 #end if 588 # track counts of tiling matrices 589 ntilings = len(mats) 590 nequiv_volume = 0 591 nfilter = 0 592 nequiv_inscribe = 0 593 nequiv_wigner = 0 594 nequiv_cubicity = 0 595 nequiv_shape = 0 596 # try a faster search for cells w/ target volume 597 det_inds_p = [ 598 [(0,0),(1,1),(2,2)], 599 [(0,1),(1,2),(2,0)], 600 [(0,2),(1,0),(2,1)] 601 ] 602 det_inds_m = [ 603 [(0,0),(1,2),(2,1)], 604 [(0,1),(1,0),(2,2)], 605 [(0,2),(1,1),(2,0)] 606 ] 607 volfacs = zeros((len(mats),),dtype=int) 608 for (i1,j1),(i2,j2),(i3,j3) in det_inds_p: 609 volfacs += (Tref[i1,j1]+mats[:,i1,j1])*(Tref[i2,j2]+mats[:,i2,j2])*(Tref[i3,j3]+mats[:,i3,j3]) 610 #end for 611 for (i1,j1),(i2,j2),(i3,j3) in det_inds_m: 612 volfacs -= (Tref[i1,j1]+mats[:,i1,j1])*(Tref[i2,j2]+mats[:,i2,j2])*(Tref[i3,j3]+mats[:,i3,j3]) 613 #end for 614 Tmats = mats[abs(volfacs)==volfac] 615 nequiv_volume = len(Tmats) 616 # find the set of cells with maximal inscribing radius 617 inscribe_tilings = [] 618 rmax = -1e99 619 for mat in Tmats: 620 T = Tref + mat 621 if filter(T): 622 nfilter+=1 623 Taxes = dot(T,axes) 624 rc1 = norm(cross(Taxes[0],Taxes[1])) 625 rc2 = norm(cross(Taxes[1],Taxes[2])) 626 rc3 = norm(cross(Taxes[2],Taxes[0])) 627 r = 0.5*volume/max(rc1,rc2,rc3) # inscribing radius 628 if r>rmax or abs(r-rmax)<tol: 629 inscribe_tilings.append((r,T,Taxes)) 630 rmax = r 631 #end if 632 #end if 633 #end for 634 # find the set of cells w/ maximal wigner radius out of the inscribing set 635 wigner_tilings = [] 636 rwmax = -1e99 637 for r,T,Taxes in inscribe_tilings: 638 if abs(r-rmax)<tol: 639 nequiv_inscribe+=1 640 rw = 1e99 641 for ind in inds: 642 rw = min(rw,0.5*norm(dot(ind,Taxes))) 643 #end for 644 if rw>rwmax or abs(rw-rwmax)<tol: 645 wigner_tilings.append((rw,T,Taxes)) 646 rwmax = rw 647 #end if 648 #end if 649 #end for 650 # find the set of cells w/ maximal cubicity 651 # (minimum cube_deviation) 652 cube_tilings = [] 653 cmin = 1e99 654 for rw,T,Ta in wigner_tilings: 655 if abs(rw-rwmax)<tol: 656 nequiv_wigner+=1 657 dc = volume**(1./3)*sqrt(2.) 658 d1 = abs(norm(Ta[0]+Ta[1])-dc) 659 d2 = abs(norm(Ta[1]+Ta[2])-dc) 660 d3 = abs(norm(Ta[2]+Ta[0])-dc) 661 d4 = abs(norm(Ta[0]-Ta[1])-dc) 662 d5 = abs(norm(Ta[1]-Ta[2])-dc) 663 d6 = abs(norm(Ta[2]-Ta[0])-dc) 664 cube_dev = (d1+d2+d3+d4+d5+d6)/(6*dc) 665 if cube_dev<cmin or abs(cube_dev-cmin)<tol: 666 cube_tilings.append((cube_dev,rw,T,Ta)) 667 cmin = cube_dev 668 #end if 669 #end if 670 #end for 671 # prioritize selection by "shapeliness" of tiling matrix 672 # prioritize positive diagonal elements 673 # penalize off-diagonal elements 674 # penalize negative off-diagonal elements 675 shapely_tilings = [] 676 smax = -1e99 677 for cd,rw,T,Taxes in cube_tilings: 678 if abs(cd-cmin)<tol: 679 nequiv_cubicity+=1 680 d = diag(T) 681 o = (T-diag(d)).ravel() 682 s = sign(d).sum()-(abs(o)>0).sum()-(o<0).sum() 683 if s>smax or abs(s-smax)<tol: 684 shapely_tilings.append((s,rw,T,Taxes)) 685 smax = s 686 #end if 687 #end if 688 #end for 689 # prioritize selection by symmetry of tiling matrix 690 ropt = -1e99 691 Topt = None 692 Taxopt = None 693 diagonal = [] 694 symmetric = [] 695 antisymmetric = [] 696 other = [] 697 for s,rw,T,Taxes in shapely_tilings: 698 if abs(s-smax)<tol: 699 nequiv_shape+=1 700 Td = diag(diag(T)) 701 if abs(Td-T).sum()==0: 702 diagonal.append((rw,T,Taxes)) 703 elif abs(T.T-T).sum()==0: 704 symmetric.append((rw,T,Taxes)) 705 elif abs(T.T+T-2*Td).sum()==0: 706 antisymmetric.append((rw,T,Taxes)) 707 else: 708 other.append((rw,T,Taxes)) 709 #end if 710 #end if 711 #end for 712 s = 1 713 if len(diagonal)>0: 714 cells = diagonal 715 elif len(symmetric)>0: 716 cells = symmetric 717 elif len(antisymmetric)>0: 718 cells = antisymmetric 719 s = -1 720 elif len(other)>0: 721 cells = other 722 #end if 723 skew_min = 1e99 724 if len(cells)>0: 725 for rw,T,Taxes in cells: 726 Td = diag(diag(T)) 727 skew = abs(T.T-s*T-(1-s)*Td).sum() 728 if skew<skew_min: 729 ropt = rw 730 Topt = T 731 Taxopt = Taxes 732 skew_min = skew 733 #end if 734 #end for 735 #end if 736 if Taxopt is None: 737 error('optimal tilematrix for volfac={0} not found with tolerance {1}\ndifference range (dn): {2}\ntiling matrices searched: {3}\ncells with target volume: {4}\ncells that passed the filter: {5}\ncells with equivalent inscribing radius: {6}\ncells with equivalent wigner radius: {7}\ncells with equivalent cubicity: {8}\nmatrices with equivalent shapeliness: {9}\nplease try again with dn={10}'.format(volfac,tol,dn,ntilings,nequiv_volume,nfilter,nequiv_inscribe,nequiv_wigner,nequiv_cubicity,nequiv_shape,dn+1)) 738 #end if 739 if det(Taxopt)<0: 740 Topt = -Topt 741 #end if 742 return Topt,ropt 743#end def optimal_tilematrix 744 745 746class Sobj(DevBase): 747 None 748#end class Sobj 749 750 751 752class Structure(Sobj): 753 754 operations = obj() 755 756 @classmethod 757 def set_operations(cls): 758 cls.operations.set( 759 remove_folded_structure = cls.remove_folded_structure, 760 recenter = cls.recenter, 761 ) 762 #end def set_operations 763 764 765 def __init__(self, 766 axes = None, 767 scale = 1., 768 elem = None, 769 pos = None, 770 elem_pos = None, 771 mag = None, 772 center = None, 773 kpoints = None, 774 kweights = None, 775 kgrid = None, 776 kshift = None, 777 permute = None, 778 units = None, 779 tiling = None, 780 rescale = True, 781 dim = 3, 782 magnetization = None, 783 operations = None, 784 background_charge = 0, 785 frozen = None, 786 bconds = None, 787 posu = None, 788 use_prim = None, 789 add_kpath = False, 790 symm_kgrid = False, 791 ): 792 793 if isinstance(axes,str): 794 axes = array(axes.split(),dtype=float) 795 axes.shape = dim,dim 796 #end if 797 if center is None: 798 if axes is not None: 799 center = array(axes,dtype=float).sum(0)/2 800 else: 801 center = dim*[0] 802 #end if 803 #end if 804 if bconds is None or bconds=='periodic': 805 bconds = dim*['p'] 806 #end if 807 if axes is None: 808 axes = [] 809 bconds = [] 810 #end if 811 if elem_pos is not None: 812 ep = array(elem_pos.split(),dtype=str) 813 ep.shape = ep.size//(dim+1),(dim+1) 814 elem = ep[:,0].ravel() 815 pos = ep[:,1:dim+1] 816 #end if 817 if elem is None: 818 elem = [] 819 #end if 820 if posu is not None: 821 pos = posu 822 #end if 823 if pos is None: 824 pos = empty((0,dim)) 825 #end if 826 if kshift is None: 827 kshift = 0,0,0 828 #end if 829 self.scale = 1. 830 self.units = units 831 self.dim = dim 832 self.center = array(center,dtype=float) 833 self.axes = array(axes,dtype=float) 834 self.set_bconds(bconds) 835 self.set_elem(elem) 836 self.set_pos(pos) 837 self.set_mag(mag) 838 self.set_frozen(frozen) 839 self.kpoints = empty((0,dim)) 840 self.kweights = empty((0,)) 841 self.background_charge = background_charge 842 self.remove_folded_structure() 843 if len(axes)==0: 844 self.kaxes=array([]) 845 else: 846 self.kaxes=2*pi*inv(self.axes).T 847 #end if 848 if posu is not None: 849 self.pos_to_cartesian() 850 #end if 851 if magnetization is not None: 852 self.magnetize(magnetization) 853 #end if 854 if use_prim is not None and use_prim is not False: 855 self.become_primitive(source=use_prim,add_kpath=add_kpath) 856 #end if 857 if tiling is not None: 858 self.tile(tiling,in_place=True) 859 #end if 860 if kpoints is not None: 861 self.add_kpoints(kpoints,kweights) 862 #end if 863 if kgrid is not None: 864 if not symm_kgrid: 865 self.add_kmesh(kgrid,kshift) 866 else: 867 self.add_symmetrized_kmesh(kgrid,kshift) 868 #end if 869 #end if 870 if rescale: 871 self.rescale(scale) 872 else: 873 self.scale = scale 874 #end if 875 if permute is not None: 876 self.permute(permute) 877 #end if 878 if operations is not None: 879 self.operate(operations) 880 #end if 881 #end def __init__ 882 883 884 def check_consistent(self,tol=1e-8,exit=True,message=False): 885 msg = '' 886 if self.has_axes(): 887 kaxes = 2*pi*inv(self.axes).T 888 abs_diff = abs(self.kaxes-kaxes).sum() 889 if abs_diff>tol: 890 msg += 'Direct and reciprocal space axes are not consistent.\naxes present:\n{0}\nkaxes present:\n{1}\nConsistent kaxes:\n{2}\nAbsolute difference: {3}\n'.format(self.axes,self.kaxes,kaxes,abs_diff) 891 #end if 892 #end if 893 N = len(self.elem) 894 D = self.dim 895 pshape = (N,D) 896 if self.pos.shape!=pshape: 897 msg += 'pos is not the right shape\npos shape: {}\nCorrect shape: {}\n'.format(self.pos.shape,pshape) 898 #end if 899 if self.mag is not None and len(self.mag)!=N: 900 msg += 'mag does not have the right length\nmag length: {}\nCorrect length: {}\n'.format(self.mag,N) 901 #end if 902 if self.frozen is not None and self.frozen.shape!=pshape: 903 msg += 'frozen is not the right shape\nfrozen shape: {}\nCorrect shape: {}\n'.format(self.frozen.shape,pshape) 904 #end if 905 consistent = len(msg)==0 906 if not consistent and exit: 907 self.error(msg) 908 #end if 909 if not message: 910 return consistent 911 else: 912 return consistent,msg 913 #end if 914 #end def check_consistent 915 916 917 def set_axes(self,axes): 918 self.reset_axes(axes) 919 #end def set_axes 920 921 922 def set_bconds(self,bconds): 923 self.bconds = array(tuple(bconds),dtype=str) 924 #end def bconds 925 926 927 def set_elem(self,elem): 928 self.elem = array(elem,dtype=object) 929 #end def set_elem 930 931 932 def set_pos(self,pos): 933 self.pos = array(pos,dtype=float) 934 if len(self.pos)!=len(self.elem): 935 self.error('Atomic positions must have same length as elem.\nelem length: {}\nAtomic positions length: {}\n'.format(len(self.elem),len(self.pos))) 936 #end if 937 #end def set_pos 938 939 940 def set_mag(self,mag=None): 941 if mag is None: 942 self.mag = None 943 else: 944 self.mag = np.array(mag,dtype=object) 945 if len(self.mag)!=len(self.elem): 946 self.error('Magnetic moments must have same length as elem.\nelem length: {}\nMagnetic moments length: {}\n'.format(len(self.elem),len(self.mag))) 947 #end if 948 #end if 949 #end def set_mag 950 951 952 def set_frozen(self,frozen=None): 953 if frozen is None: 954 self.frozen = None 955 else: 956 self.frozen = np.array(frozen,dtype=bool) 957 if self.frozen.shape!=self.pos.shape: 958 self.error('Frozen directions must have the same shape as positions.\nPositions shape: {0}\nFrozen directions shape: {1}'.format(self.pos.shape,self.frozen.shape)) 959 #end if 960 #end if 961 #end def set_frozen 962 963 964 def size(self): 965 return len(self.elem) 966 #end def size 967 968 969 def has_axes(self): 970 return len(self.axes)==self.dim 971 #end def has_axes 972 973 974 def operate(self,operations): 975 for op in operations: 976 if not op in self.operations: 977 self.error('{0} is not a known operation\nvalid options are:\n {1}'.format(op,list(self.operations.keys()))) 978 else: 979 self.operations[op](self) 980 #end if 981 #end for 982 #end def operate 983 984 985 def has_tmatrix(self): 986 return 'tmatrix' in self and self.tmatrix is not None 987 #end def has_tmatrix 988 989 990 def is_tiled(self): 991 return self.has_folded() and self.has_tmatrix() 992 #end def is_tiled 993 994 995 def set_folded(self,folded): 996 self.set_folded_structure(folded) 997 #end def set_folded 998 999 1000 def remove_folded(self): 1001 self.remove_folded_structure() 1002 #end def remove_folded 1003 1004 1005 def has_folded(self): 1006 return self.has_folded_structure() 1007 #end def has_folded 1008 1009 1010 def set_folded_structure(self,folded): 1011 if not isinstance(folded,Structure): 1012 self.error('cannot set folded structure\nfolded structure must be an object with type Structure\nreceived type: {0}'.format(folded.__class__.__name__)) 1013 #end if 1014 self.folded_structure = folded 1015 if self.has_axes(): 1016 self.tmatrix = self.tilematrix(folded) 1017 #end if 1018 #end def set_folded_structure 1019 1020 1021 def remove_folded_structure(self): 1022 self.folded_structure = None 1023 if 'tmatrix' in self: 1024 del self.tmatrix 1025 #end if 1026 #end def remove_folded_structure 1027 1028 1029 def has_folded_structure(self): 1030 return self.folded_structure is not None 1031 #end def has_folded_structure 1032 1033 1034 # test needed 1035 def group_atoms(self,folded=True): 1036 if len(self.elem)>0: 1037 order = self.elem.argsort() 1038 if (self.elem!=self.elem[order]).any(): 1039 self.elem = self.elem[order] 1040 self.pos = self.pos[order] 1041 #end if 1042 #end if 1043 if self.folded_structure!=None and folded: 1044 self.folded_structure.group_atoms(folded) 1045 #end if 1046 #end def group_atoms 1047 1048 1049 # test needed 1050 def rename(self,folded=True,**name_pairs): 1051 elem = self.elem 1052 for old,new in name_pairs.items(): 1053 for i in range(len(self.elem)): 1054 if old==elem[i]: 1055 elem[i] = new 1056 #end if 1057 #end for 1058 #end for 1059 if self.folded_structure!=None and folded: 1060 self.folded_structure.rename(folded=folded,**name_pairs) 1061 #end if 1062 #end def rename 1063 1064 1065 # test needed 1066 def reset_axes(self,axes=None): 1067 if axes is None: 1068 axes = self.axes 1069 else: 1070 axes = array(axes) 1071 self.remove_folded_structure() 1072 #end if 1073 self.axes = axes 1074 self.kaxes = 2*pi*inv(axes).T 1075 self.center = axes.sum(0)/2 1076 #end def reset_axes 1077 1078 1079 # test needed 1080 def adjust_axes(self,axes): 1081 self.skew(dot(inv(self.axes),axes)) 1082 #end def adjust_axes 1083 1084 1085 # test needed 1086 def reshape_axes(self,reshaping): 1087 R = array(reshaping) 1088 if abs(abs(det(R))-1)<1e-6: 1089 self.axes = dot(self.axes,R) 1090 else: 1091 R = dot(inv(self.axes),R) 1092 if abs(abs(det(R))-1)<1e-6: 1093 self.axes = dot(self.axes,R) 1094 else: 1095 self.error('reshaping matrix must not change the volume\n reshaping matrix:\n {0}\n volume change ratio: {1}'.format(R,abs(det(R)))) 1096 #end if 1097 #end if 1098 #end def reshape_axes 1099 1100 1101 def write_axes(self): 1102 c = '' 1103 for a in self.axes: 1104 c+='{0:12.8f} {1:12.8f} {2:12.8f}\n'.format(a[0],a[1],a[2]) 1105 #end for 1106 return c 1107 #end def write_axes 1108 1109 1110 def corners(self): 1111 a = self.axes 1112 c = array([(0,0,0), 1113 a[0], 1114 a[1], 1115 a[2], 1116 a[0]+a[1], 1117 a[1]+a[2], 1118 a[2]+a[0], 1119 a[0]+a[1]+a[2], 1120 ]) 1121 return c 1122 #end def corners 1123 1124 1125 # test needed 1126 def miller_direction(self,h,k,l,normalize=False): 1127 d = dot((h,k,l),self.axes) 1128 if normalize: 1129 d/=norm(d) 1130 #end if 1131 return d 1132 #end def miller_direction 1133 1134 1135 # test needed 1136 def miller_normal(self,h,k,l,normalize=False): 1137 d = dot((h,k,l),self.kaxes) 1138 if normalize: 1139 d/=norm(d) 1140 #end if 1141 return d 1142 #end def miller_normal 1143 1144 1145 # test needed 1146 def project_plane(self,a1,a2,points=None): 1147 # a1/a2: in plane vectors 1148 if points is None: 1149 points = self.pos 1150 #end if 1151 a1n = norm(a1) 1152 a2n = norm(a2) 1153 a1/=a1n 1154 a2/=a2n 1155 n = cross(a1,a2) 1156 plane_coords = [] 1157 for p in points: 1158 p -= dot(n,p)*n # project point into plane 1159 c1 = dot(a1,p)/a1n 1160 c2 = dot(a2,p)/a2n 1161 plane_coords.append((c1,c2)) 1162 #end for 1163 return array(plane_coords,dtype=float) 1164 #end def project_plane 1165 1166 1167 def bounding_box(self,scale=1.0,minsize=None,mindist=0,box='tight',recenter=False): 1168 pmin = self.pos.min(0)-mindist 1169 pmax = self.pos.max(0)+mindist 1170 pcenter = (pmax+pmin)/2 1171 prange = pmax-pmin 1172 if minsize is not None: 1173 for i,pr in enumerate(prange): 1174 prange[i] = max(minsize,prange[i]) 1175 #end for 1176 #end if 1177 if box=='tight': 1178 axes = diag(prange) 1179 elif box=='cubic' or box=='cube': 1180 prmax = prange.max() 1181 axes = diag((prmax,prmax,prmax)) 1182 elif isinstance(box,ndarray) or isinstance(box,list): 1183 box = array(box) 1184 if box.shape!=(3,3): 1185 self.error('requested box must be 3-dimensional (3x3 axes)\n you provided: '+str(box)+'\n shape: '+str(box.shape)) 1186 #end if 1187 binv = inv(box) 1188 pu = dot(self.pos,binv) 1189 pmin = pu.min(0) 1190 pmax = pu.max(0) 1191 pcenter = (pmax+pmin)/2 1192 prange = pmax-pmin 1193 axes = dot(diag(prange),box) 1194 else: 1195 self.error("invalid request for box\n valid options are 'tight', 'cubic', or axes array (3x3)\n you provided: "+str(box)) 1196 #end if 1197 self.reset_axes(scale*axes) 1198 self.slide(self.center-pcenter,recenter) 1199 #end def bounding_box 1200 1201 1202 def center_molecule(self): 1203 self.slide(self.center-self.pos.mean(0),recenter=False) 1204 #end def center_molecule 1205 1206 1207 # test needed 1208 def center_solid(self): 1209 u = self.pos_unit() 1210 du = (1-u.min(0)-u.max(0))/2 1211 self.slide(dot(du,self.axes),recenter=False) 1212 #end def center_solid 1213 1214 1215 # test needed 1216 def permute(self,permutation): 1217 dim = self.dim 1218 P = empty((dim,dim),dtype=int) 1219 if len(permutation)!=dim: 1220 self.error(' permutation vector must have {0} elements\n you provided {1}'.format(dim,permutation)) 1221 #end if 1222 for i in range(dim): 1223 p = permutation[i] 1224 pv = zeros((dim,),dtype=int) 1225 if p=='x' or p=='0': 1226 pv[0] = 1 1227 elif p=='y' or p=='1': 1228 pv[1] = 1 1229 elif p=='z' or p=='2': 1230 pv[2] = 1 1231 #end if 1232 P[:,i] = pv[:] 1233 #end for 1234 self.center = dot(self.center,P) 1235 if self.has_axes(): 1236 self.axes = dot(self.axes,P) 1237 #end if 1238 if len(self.pos)>0: 1239 self.pos = dot(self.pos,P) 1240 #end if 1241 if len(self.kaxes)>0: 1242 self.kaxes = dot(self.kaxes,P) 1243 #end if 1244 if len(self.kpoints)>0: 1245 self.kpoints = dot(self.kpoints,P) 1246 #end if 1247 if self.folded_structure!=None: 1248 self.folded_structure.permute(permutation) 1249 #end if 1250 #end def permute 1251 1252 1253 # test needed 1254 def rotate_plane(self,plane,angle,units='degrees'): 1255 self.pos = rotate_plane(plane,angle,self.pos,units) 1256 if self.has_axes(): 1257 axes = rotate_plane(plane,angle,self.axes,units) 1258 self.reset_axes(axes) 1259 #end if 1260 #end def rotate_plane 1261 1262 1263 # test needed 1264 def upcast(self,DerivedStructure): 1265 if not issubclass(DerivedStructure,Structure): 1266 self.error(DerivedStructure.__name__,'is not derived from Structure') 1267 #end if 1268 ds = DerivedStructure() 1269 for name,value in self.items(): 1270 ds[name] = deepcopy(value) 1271 #end for 1272 return ds 1273 #end def upcast 1274 1275 1276 # test needed 1277 def incorporate(self,other): 1278 self.set_elem(list(self.elem)+list(other.elem)) 1279 self.pos=array(list(self.pos)+list(other.pos)) 1280 #end def incorporate 1281 1282 1283 # test needed 1284 def clone_from(self,other): 1285 if not isinstance(other,Structure): 1286 self.error('cloning failed\ncan only clone from other Structure objects\nreceived object of type: {0}'.format(other.__class__.__name__)) 1287 #end if 1288 o = other.copy() 1289 self.__dict__ = o.__dict__ 1290 #end def clone_from 1291 1292 1293 # test needed 1294 def add_atoms(self,elem,pos): 1295 self.set_elem(list(self.elem)+list(elem)) 1296 self.pos=array(list(self.pos)+list(pos)) 1297 #end def add_atoms 1298 1299 1300 def is_open(self): 1301 return not self.any_periodic() 1302 #end def is_open 1303 1304 1305 def is_periodic(self): 1306 return self.any_periodic() 1307 #end def is_periodic 1308 1309 1310 def any_periodic(self): 1311 has_cell = self.has_axes() 1312 pbc = False 1313 for bc in self.bconds: 1314 pbc |= bc=='p' 1315 #end if 1316 periodic = has_cell and pbc 1317 return periodic 1318 #end def any_periodic 1319 1320 1321 def all_periodic(self): 1322 has_cell = self.has_axes() 1323 pbc = True 1324 for bc in self.bconds: 1325 pbc &= bc=='p' 1326 #end if 1327 periodic = has_cell and pbc 1328 return periodic 1329 #end def all_periodic 1330 1331 1332 # test needed 1333 def distances(self,pos1=None,pos2=None): 1334 if isinstance(pos1,Structure): 1335 pos1 = pos1.pos 1336 #end if 1337 if pos2 is None: 1338 if pos1 is None: 1339 return sqrt((self.pos**2).sum(1)) 1340 else: 1341 pos2 = self.pos 1342 #end if 1343 #end if 1344 if len(pos1)!=len(pos2): 1345 self.error('positions arrays are not the same length') 1346 #end if 1347 return sqrt(((pos1-pos2)**2).sum(1)) 1348 #end def distances 1349 1350 1351 def count_kshells(self, kcut, tilevec=[12, 12, 12], nkdig=10): 1352 # check tilevec input 1353 for nt in tilevec: 1354 if nt % 2 != 0: 1355 msg = 'tilevec must contain even integers' 1356 msg += ' so that kgrid can be zero centered.' 1357 Structure.class_error(msg, 'count_kshells') 1358 #end if 1359 #end for 1360 1361 origin = np.array([[0, 0, 0]]) 1362 axes = self.axes 1363 raxes = 2*np.pi*np.linalg.inv(axes).T 1364 kvecs = self.tile_points(origin, raxes, tilevec) 1365 kvecs -= np.dot(tilevec, raxes)/2 # center around 0 1366 kmags = np.linalg.norm(kvecs, axis=-1) 1367 1368 # make sure tilevec is sufficient for kcut 1369 klimit = 0.5*kmags.max() 1370 if kcut > klimit: 1371 msg = 'kcut %3.2f > klimit=%3.2f\n' % (kcut, klimit) 1372 msg += ' please increase tilevec to be safe.\n' 1373 Structure.class_error(msg, 'count_kshells') 1374 #end if 1375 1376 sel = (0<kmags) & (kmags<kcut) 1377 ukmags = np.unique(kmags[sel].round(nkdig)) 1378 return len(ukmags) 1379 #end def count_kshells 1380 1381 1382 def volume(self): 1383 if not self.has_axes(): 1384 return None 1385 else: 1386 return abs(det(self.axes)) 1387 #end if 1388 #end def volume 1389 1390 1391 def rwigner(self,nc=5): 1392 if self.dim!=3: 1393 self.error('rwigner is currently only implemented for 3 dimensions') 1394 #end if 1395 rmin = 1e90 1396 n=empty((1,3)) 1397 rng = tuple(range(-nc,nc+1)) 1398 for k in rng: 1399 for j in rng: 1400 for i in rng: 1401 if i!=0 or j!=0 or k!=0: 1402 n[:] = i,j,k 1403 rmin = min(rmin,.5*norm(dot(n,self.axes))) 1404 #end if 1405 #end for 1406 #end for 1407 #end for 1408 return rmin 1409 #end def rwigner 1410 1411 1412 def rinscribe(self): 1413 if self.dim!=3: 1414 self.error('rinscribe is currently only implemented for 3 dimensions') 1415 #end if 1416 radius = 1e99 1417 dim=3 1418 axes=self.axes 1419 for d in range(dim): 1420 i = d 1421 j = (d+1)%dim 1422 rc = cross(axes[i,:],axes[j,:]) 1423 radius = min(radius,.5*abs(det(axes))/norm(rc)) 1424 #end for 1425 return radius 1426 #end def rinscribe 1427 1428 1429 # test needed 1430 def rwigner_cube(self,*args,**kwargs): 1431 cube = Structure() 1432 a = self.volume()**(1./3) 1433 cube.set_axes([[a,0,0],[0,a,0],[0,0,a]]) 1434 return cube.rwigner(*args,**kwargs) 1435 #end def rwigner_cube 1436 1437 1438 # test needed 1439 def rinscribe_cube(self,*args,**kwargs): 1440 cube = Structure() 1441 a = self.volume()**(1./3) 1442 cube.set_axes([[a,0,0],[0,a,0],[0,0,a]]) 1443 return cube.rinscribe(*args,**kwargs) 1444 #end def rinscribe_cube 1445 1446 1447 def rmin(self): 1448 return self.rwigner() 1449 #end def rmin 1450 1451 1452 def rcell(self): 1453 return self.rinscribe() 1454 #end def rcell 1455 1456 1457 # test needed 1458 # scale invariant measure of deviation from cube shape 1459 # based on deviation of face diagonals from cube 1460 def cube_deviation(self): 1461 a = self.axes 1462 dc = self.volume()**(1./3)*sqrt(2.) 1463 d1 = abs(norm(a[0]+a[1])-dc) 1464 d2 = abs(norm(a[1]+a[2])-dc) 1465 d3 = abs(norm(a[2]+a[0])-dc) 1466 d4 = abs(norm(a[0]-a[1])-dc) 1467 d5 = abs(norm(a[1]-a[2])-dc) 1468 d6 = abs(norm(a[2]-a[0])-dc) 1469 return (d1+d2+d3+d4+d5+d6)/(6*dc) 1470 #end def cube_deviation 1471 1472 1473 # test needed 1474 # apply volume preserving shear-removing transformations to cell axes 1475 # resulting unsheared cell has orthogonal axes 1476 # while remaining periodically correct 1477 # note that the unshearing procedure is not unique 1478 # it depends on the order of unshearing operations 1479 def unsheared_axes(self,axes=None,distances=False): 1480 if self.dim!=3: 1481 self.error('unsheared_axes is currently only implemented for 3 dimensions') 1482 #end if 1483 if axes is None: 1484 axes = self.axes 1485 #end if 1486 dim=3 1487 axbar = identity(dim) 1488 axnew = array(axes,dtype=float) 1489 dists = empty((dim,)) 1490 for d in range(dim): 1491 d2 = (d+1)%dim 1492 d3 = (d+2)%dim 1493 n = cross(axnew[d2],axnew[d3]) #vector normal to 2 cell faces 1494 axdist = dot(n,axes[d])/dot(n,axbar[d]) 1495 axnew[d] = axdist*axbar[d] 1496 dists[d] = axdist 1497 #end for 1498 if not distances: 1499 return axnew 1500 else: 1501 return axnew,dists 1502 #end if 1503 #end def unsheared_axes 1504 1505 1506 # test needed 1507 # vectors parallel to cell faces 1508 # length of vectors is distance between parallel face planes 1509 # note that the product of distances is not the cell volume in general 1510 # see "unsheared_axes" function 1511 # (e.g. a volume preserving shear may bring two face planes arbitrarily close) 1512 def face_vectors(self,axes=None,distances=False): 1513 if axes is None: 1514 axes = self.axes 1515 #end if 1516 fv = inv(axes).T 1517 for d in range(len(fv)): 1518 fv[d] /= norm(fv[d]) # face normals 1519 #end for 1520 dv = dot(axes,fv.T) # axis projections onto face normals 1521 fv = dot(dv,fv) # face normals lengthened by plane separation 1522 if not distances: 1523 return fv 1524 else: 1525 return fv,diag(dv) 1526 #end if 1527 #end def face_vectors 1528 1529 1530 # test needed 1531 def face_distances(self): 1532 return self.face_vectors(distances=True)[1] 1533 #end def face_distances 1534 1535 1536 # test needed 1537 def rescale(self,scale): 1538 self.scale *= scale 1539 self.axes *= scale 1540 self.pos *= scale 1541 self.center *= scale 1542 self.kaxes /= scale 1543 self.kpoints/= scale 1544 if self.folded_structure!=None: 1545 self.folded_structure.rescale(scale) 1546 #end if 1547 #end def rescale 1548 1549 1550 # test needed 1551 def stretch(self,s1,s2,s3): 1552 if self.dim!=3: 1553 self.error('stretch is currently only implemented for 3 dimensions') 1554 #end if 1555 d = diag((s1,s2,s3)) 1556 self.skew(d) 1557 #end def stretch 1558 1559 1560 # test needed 1561 def rotate(self,r,rp=None,passive=False,units="radians",check=True): 1562 """ 1563 Arbitrary rotation of the structure. 1564 Parameters 1565 ---------- 1566 r : `array_like, float, shape (3,3)` or `array_like, float, shape (3,)` or `str` 1567 If a 3x3 matrix, then code executes rotation consistent with this matrix -- 1568 it is assumed that the matrix acts on a column-major vector (eg, v'=Rv) 1569 If a three-dimensional array, then the operation of the function depends 1570 on the input type of rp in the following ways: 1571 1. If rp is a scalar, then rp is assumed to be an angle and a rotation 1572 of rp is made about the axis defined by r 1573 2. If rp is a vector, then rp is assumed to be an axis and a rotation is made 1574 such that r aligns with rp 1575 3. If rp is a str, then the rotation is such that r aligns with the 1576 axis given by the str ('x', 'y', 'z', 'a0', 'a1', or 'a2') 1577 If a str then the axis, r, is defined by the input label (e.g. 'x', 'y', 'z', 'a1', 'a2', or 'a3') 1578 and the operation of the function depends on the input type of rp in the following 1579 ways (same as above): 1580 1. If rp is a scalar, then rp is assumed to be an angle and a rotation 1581 of rp is made about the axis defined by r 1582 2. If rp is a vector, then rp is assumed to be an axis and a rotation is made 1583 such that r aligns with rp 1584 3. If rp is a str, then the rotation is such that r aligns with the 1585 axis given by the str ('x', 'y', 'z', 'a0', 'a1', or 'a2') 1586 rp : `array_like, float, shape (3), optional` or `str, optional` 1587 If a 3-dimensional vector is given, then rp is assumed to be an axis and a rotation is made 1588 such that the axis r is aligned with rp. 1589 If a str, then rp is assumed to be an angle and a rotation about the axis defined by r 1590 is made by an angle rp 1591 If a str is given, then rp is assumed to be an axis defined by the given label 1592 (e.g. 'x', 'y', 'z', 'a1', 'a2', or 'a3') and a rotation is made such that the axis r 1593 is aligned with rp. 1594 passive : `bool, optional, default False` 1595 If `True`, perform a passive rotation 1596 If `False`, perform an active rotation 1597 units : `str, optional, default "radians"` 1598 Units of rp, if rp is given as an angle (scalar) 1599 check : `bool, optional, default True` 1600 Perform a check to verify rotation matrix is orthogonal 1601 """ 1602 if rp is not None: 1603 dirmap = dict(x=[1,0,0],y=[0,1,0],z=[0,0,1]) 1604 if isinstance(r,str): 1605 if r[0]=='a': # r= 'a0', 'a1', or 'a2' 1606 r = self.axes[int(r[1])] 1607 else: # r= 'x', 'y', or 'z' 1608 r = dirmap[r] 1609 #end if 1610 else: 1611 r = array(r,dtype=float) 1612 if len(r.shape)>1: 1613 self.error('r must be given as a 1-d vector or string, if rp is not None') 1614 #end if 1615 #end if 1616 if isinstance(rp,(int,float)): 1617 if units=="radians" or units=="rad": 1618 theta = float(rp) 1619 else: 1620 theta = float(rp)*np.pi/180.0 1621 #end if 1622 c = np.cos(theta) 1623 s = np.sin(theta) 1624 else: 1625 if isinstance(rp,str): 1626 if rp[0]=='a': # rp= 'a0', 'a1', or 'a2' 1627 rp = self.axes[int(rp[1])] 1628 else: # rp= 'x', 'y', or 'z' 1629 rp = dirmap[rp] 1630 #end if 1631 else: 1632 rp = array(rp,dtype=float) 1633 #end if 1634 # go from r,rp to r,theta 1635 c = np.dot(r,rp)/np.linalg.norm(r)/np.linalg.norm(rp) 1636 if abs(c-1)<1e-6: 1637 s = 0.0 1638 r = np.array([1,0,0]) 1639 else: 1640 s = np.dot(np.cross(r,rp),np.cross(r,rp))/np.linalg.norm(r)/np.linalg.norm(rp)/np.linalg.norm(np.cross(r,rp)) 1641 r = np.cross(r,rp)/np.linalg.norm(np.cross(r,rp)) 1642 #end if 1643 #end if 1644 # make R from r,theta 1645 R = [[ c+r[0]**2.0*(1.0-c), r[0]*r[1]*(1.0-c)-r[2]*s, r[0]*r[2]*(1.0-c)+r[1]*s], 1646 [r[1]*r[0]*(1.0-c)+r[2]*s, c+r[1]**2.0*(1.0-c), r[1]*r[2]*(1.0-c)-r[0]*s], 1647 [r[2]*r[0]*(1.0-c)-r[1]*s, r[2]*r[1]*(1.0-c)+r[0]*s, c+r[2]**2.0*(1.0-c)]] 1648 else: 1649 R = r 1650 #end if 1651 R = array(R,dtype=float) 1652 if passive: 1653 R = R.T 1654 #end if 1655 if check: 1656 if not np.allclose(dot(R,R.T),identity(len(R))): 1657 self.error('the function, rotate, must be given an orthogonal matrix') 1658 #end if 1659 #end if 1660 self.matrix_transform(R) 1661 #end def rotate 1662 1663 1664 # test needed 1665 def matrix_transform(self,A): 1666 """ 1667 Arbitrary transformation matrix (column-major). 1668 1669 Parameters 1670 ---------- 1671 A : `array_like, float, shape (3,3)` 1672 Transform the structure using the matrix A. It is assumed that 1673 A is in column-major form, i.e., it transforms a vector v as 1674 v' = Av 1675 """ 1676 A = A.T 1677 axinv = inv(self.axes) 1678 axnew = dot(self.axes,A) 1679 kaxinv = inv(self.kaxes) 1680 kaxnew = dot(self.kaxes,inv(A).T) 1681 self.pos = dot(dot(self.pos,axinv),axnew) 1682 self.center = dot(dot(self.center,axinv),axnew) 1683 self.kpoints = dot(dot(self.kpoints,kaxinv),kaxnew) 1684 self.axes = axnew 1685 self.kaxes = kaxnew 1686 if self.folded_structure!=None: 1687 self.folded_structure.matrix_transform(A.T) 1688 #end if 1689 #end def matrix_transform 1690 1691 1692 # test needed 1693 def skew(self,skew): 1694 """ 1695 Arbitrary transformation matrix (row-major). 1696 1697 Parameters 1698 ---------- 1699 skew : `array_like, float, shape (3,3)` 1700 Transform the structure using the matrix skew. It is assumed that 1701 skew is in row-major form, i.e., it transforms a vector v as 1702 v' = vT 1703 """ 1704 self.matrix_transform(skew.T) 1705 #end def skew 1706 1707 1708 # test needed 1709 def change_units(self,units,folded=True): 1710 if units!=self.units: 1711 scale = convert(1,self.units,units) 1712 self.scale *= scale 1713 self.axes *= scale 1714 self.pos *= scale 1715 self.center *= scale 1716 self.kaxes /= scale 1717 self.kpoints/= scale 1718 self.units = units 1719 #end if 1720 if self.folded_structure!=None and folded: 1721 self.folded_structure.change_units(units,folded=folded) 1722 #end if 1723 #end def change_units 1724 1725 1726 # test needed 1727 # insert sep space at loc along axis 1728 # if sep<0, space is removed instead 1729 def cleave(self,axis,loc,sep=None,remove=False,tol=1e-6): 1730 self.remove_folded_structure() 1731 if isinstance(axis,int): 1732 if sep is None: 1733 self.error('separation induced by cleave must be provided') 1734 #end if 1735 v = self.face_vectors()[axis] 1736 if isinstance(loc,float): 1737 c = loc*v/norm(v) 1738 #end if 1739 else: 1740 v = axis 1741 #end if 1742 c = array(c) # point on cleave plane 1743 v = array(v) # normal vector to cleave plane, norm is cleave separation 1744 if sep!=None: 1745 v = abs(sep)*v/norm(v) 1746 #end if 1747 if norm(v)<tol: 1748 return 1749 #end if 1750 vn = array(v/norm(v)) 1751 if sep!=None and sep<0: 1752 v = -v # preserve the normal direction for atom identification, but reverse the shift direction 1753 #end if 1754 self.recorner() # want box contents to be static 1755 if self.has_axes(): 1756 components = 0 1757 dim = self.dim 1758 axes = self.axes 1759 for i in range(dim): 1760 i2 = (i+1)%dim 1761 i3 = (i+2)%dim 1762 a2 = axes[i2]/norm(axes[i2]) 1763 a3 = axes[i3]/norm(axes[i3]) 1764 comp = abs(dot(a2,vn))+abs(dot(a3,vn)) 1765 if comp < 1e-6: 1766 components+=1 1767 iaxis = i 1768 #end if 1769 #end for 1770 commensurate = components==1 1771 if not commensurate: 1772 self.error('cannot insert vacuum because cleave is incommensurate with the cell\n cleave plane must be parallel to a cell face') 1773 #end if 1774 a = self.axes[iaxis] 1775 #self.axes[iaxis] = (1.+dot(v,a)/dot(a,a))*a 1776 self.axes[iaxis] = (1.+dot(v,v)/dot(v,a))*a 1777 #end if 1778 indices = [] 1779 pos = self.pos 1780 for i in range(len(pos)): 1781 p = pos[i] 1782 comp = dot(p-c,vn) 1783 if comp>0 or abs(comp)<tol: 1784 pos[i] += v 1785 indices.append(i) 1786 #end if 1787 #end for 1788 if remove: 1789 self.remove(indices) 1790 #end if 1791 #end def cleave 1792 1793 1794 # test needed 1795 def translate(self,v): 1796 v = array(v) 1797 pos = self.pos 1798 for i in range(len(pos)): 1799 pos[i]+=v 1800 #end for 1801 self.center+=v 1802 if self.folded_structure!=None: 1803 self.folded_structure.translate(v) 1804 #end if 1805 #end def translate 1806 1807 1808 # test needed 1809 def slide(self,v,recenter=True): 1810 v = array(v) 1811 pos = self.pos 1812 for i in range(len(pos)): 1813 pos[i]+=v 1814 #end for 1815 if recenter: 1816 self.recenter() 1817 #end if 1818 if self.folded_structure!=None: 1819 self.folded_structure.slide(v,recenter) 1820 #end if 1821 #end def slide 1822 1823 1824 # test needed 1825 def zero_corner(self): 1826 corner = self.center-self.axes.sum(0)/2 1827 self.translate(-corner) 1828 #end def zero_corner 1829 1830 1831 # test needed 1832 def locate_simple(self,pos): 1833 pos = array(pos) 1834 if pos.shape==(self.dim,): 1835 pos = [pos] 1836 #end if 1837 nn = nearest_neighbors(1,self.pos,pos) 1838 return nn.ravel() 1839 #end def locate_simple 1840 1841 1842 # test needed 1843 def locate(self,identifiers,radii=None,exterior=False): 1844 indices = None 1845 if isinstance(identifiers,Structure): 1846 cell = identifiers 1847 indices = cell.inside(self.pos) 1848 elif isinstance(identifiers,ndarray) and identifiers.dtype==bool: 1849 indices = arange(len(self.pos))[identifiers] 1850 elif isinstance(identifiers,int): 1851 indices = [identifiers] 1852 elif len(identifiers)>0 and isinstance(identifiers[0],int): 1853 indices = identifiers 1854 elif isinstance(identifiers,str): 1855 atom = identifiers 1856 indices = [] 1857 for i in range(len(self.elem)): 1858 if self.elem[i]==atom: 1859 indices.append(i) 1860 #end if 1861 #end for 1862 elif len(identifiers)>0 and isinstance(identifiers[0],str): 1863 indices = [] 1864 for atom in identifiers: 1865 for i in range(len(self.elem)): 1866 if self.elem[i]==atom: 1867 indices.append(i) 1868 #end if 1869 #end for 1870 #end for 1871 #end if 1872 if radii is not None or indices is None: 1873 if indices is None: 1874 pos = identifiers 1875 else: 1876 pos = self.pos[indices] 1877 #end if 1878 if isinstance(radii,float) or isinstance(radii,int): 1879 radii = len(pos)*[radii] 1880 elif radii is not None and len(radii)!=len(pos): 1881 self.error('lengths of input radii and positions do not match\n len(radii)={0}\n len(pos)={1}'.format(len(radii),len(pos))) 1882 #end if 1883 dtable = self.min_image_distances(pos) 1884 indices = [] 1885 if radii is None: 1886 for i in range(len(pos)): 1887 indices.append(dtable[i].argmin()) 1888 #end for 1889 else: 1890 ipos = arange(len(self.pos)) 1891 for i in range(len(pos)): 1892 indices.extend(ipos[dtable[i]<radii[i]]) 1893 #end for 1894 #end if 1895 #end if 1896 if exterior: 1897 indices = list(set(range(len(self.pos)))-set(indices)) 1898 #end if 1899 return indices 1900 #end def locate 1901 1902 1903 def freeze(self,identifiers=None,radii=None,exterior=False,negate=False,directions='xyz'): 1904 if isinstance(identifiers,ndarray) and identifiers.shape==self.pos.shape and identifiers.dtype==bool: 1905 if negate: 1906 self.frozen = ~identifiers 1907 else: 1908 self.frozen = identifiers.copy() 1909 #end if 1910 return 1911 #end if 1912 if identifiers is None: 1913 indices = arange(len(self.pos),dtype=int) 1914 else: 1915 indices = self.locate(identifiers,radii,exterior) 1916 #end if 1917 if len(indices)==0: 1918 self.error('failed to select any atoms to freeze') 1919 #end if 1920 if isinstance(directions,str): 1921 d = empty((3,),dtype=bool) 1922 d[0] = 'x' in directions 1923 d[1] = 'y' in directions 1924 d[2] = 'z' in directions 1925 directions = len(indices)*[d] 1926 else: 1927 directions = array(directions,dtype=bool) 1928 #end if 1929 if self.frozen is None: 1930 self.frozen = zeros(self.pos.shape,dtype=bool) 1931 #end if 1932 frozen = self.frozen 1933 i=0 1934 if not negate: 1935 for index in indices: 1936 frozen[index] = directions[i] 1937 i+=1 1938 #end for 1939 else: 1940 for index in indices: 1941 frozen[index] = directions[i]==False 1942 i+=1 1943 #end for 1944 #end if 1945 #end def freeze 1946 1947 1948 def is_frozen(self): 1949 if self.frozen is None: 1950 return np.zeros((len(self.pos),),dtype=bool) 1951 else: 1952 return self.frozen.sum(1)>0 1953 #end if 1954 #end def is_frozen 1955 1956 1957 # test needed 1958 def magnetize(self,identifiers=None,magnetization='',**mags): 1959 magsin = None 1960 if isinstance(identifiers,obj): 1961 magsin = identifiers.copy() 1962 elif isinstance(magnetization,obj): 1963 magsin = magnetization.copy() 1964 #endif 1965 if magsin!=None: 1966 magsin.transfer_from(mags) 1967 mags = magsin 1968 identifiers = None 1969 magnetization = '' 1970 #end if 1971 for e,m in mags.items(): 1972 if not e in self.elem: 1973 self.error('cannot magnetize non-existent element {0}'.format(e)) 1974 elif m is not None or not isinstance(m,int): 1975 self.error('magnetizations provided must be either None or integer\n you provided: {0}\n full magnetization request provided:\n {1}'.format(m,mags)) 1976 #end if 1977 self.mag[self.elem==e] = m 1978 #end for 1979 if identifiers is None and magnetization=='': 1980 return 1981 elif magnetization=='': 1982 magnetization = identifiers 1983 indices = list(range(len(self.elem))) 1984 else: 1985 indices = self.locate(identifiers) 1986 #end if 1987 if not isinstance(magnetization,(list,tuple,ndarray)): 1988 magnetization = [magnetization] 1989 #end if 1990 for m in magnetization: 1991 if m is not None or not isinstance(m,int): 1992 self.error('magnetizations provided must be either None or integer\n you provided: {0}\n full magnetization list provided: {1}'.format(m,magnetization)) 1993 #end if 1994 #end for 1995 if len(magnetization)==1: 1996 m = magnetization[0] 1997 for i in indices: 1998 self.mag[i] = m 1999 #end for 2000 elif len(magnetization)==len(indices): 2001 for i in range(len(indices)): 2002 self.mag[indices[i]] = magnetization[i] 2003 #end for 2004 else: 2005 self.error('magnetization list and list selected atoms differ in length\n length of magnetization list: {0}\n number of atoms selected: {1}\n magnetization list: {2}\n atom indices selected: {3}\n atoms selected: {4}'.format(len(magnetization),len(indices),magnetization,indices,self.elem[indices])) 2006 #end if 2007 #end def magnetize 2008 2009 2010 def is_magnetic(self,tol=1e-8): 2011 magnetic = False 2012 if self.mag is not None: 2013 for m in self.mag: 2014 if m is not None and abs(m)>tol: 2015 magnetic = True 2016 break 2017 #end if 2018 #end for 2019 #end if 2020 return magnetic 2021 #end def is_magnetic 2022 2023 2024 # test needed 2025 def carve(self,identifiers): 2026 indices = self.locate(identifiers) 2027 if isinstance(identifiers,Structure): 2028 sub = identifiers 2029 sub.elem = self.elem[indices].copy() 2030 sub.pos = self.pos[indices].copy() 2031 else: 2032 sub = self.copy() 2033 sub.elem = self.elem[indices] 2034 sub.pos = self.pos[indices] 2035 #end if 2036 sub.host_indices = array(indices) 2037 return sub 2038 #end def carve 2039 2040 2041 # test needed 2042 def remove(self,identifiers): 2043 indices = self.locate(identifiers) 2044 keep = list(set(range(len(self.pos)))-set(indices)) 2045 erem = self.elem[indices] 2046 prem = self.pos[indices] 2047 self.elem = self.elem[keep] 2048 self.pos = self.pos[keep] 2049 self.remove_folded_structure() 2050 return erem,prem 2051 #end def remove 2052 2053 2054 # test needed 2055 def replace(self,identifiers,elem=None,pos=None,radii=None,exterior=False): 2056 indices = self.locate(identifiers,radii,exterior) 2057 if isinstance(elem,Structure): 2058 cell = elem 2059 elem = cell.elem 2060 pos = cell.pos 2061 elif elem==None: 2062 elem = self.elem 2063 #end if 2064 indices=array(indices) 2065 elem=array(elem,dtype=object) 2066 pos =array(pos) 2067 nrem = len(indices) 2068 nadd = len(pos) 2069 if nadd<nrem: 2070 ar = array(list(range(0,nadd))) 2071 rr = array(list(range(nadd,nrem))) 2072 self.elem[indices[ar]] = elem[:] 2073 self.pos[indices[ar]] = pos[:] 2074 self.remove(indices[rr]) 2075 elif nadd>nrem: 2076 ar = array(list(range(0,nrem))) 2077 er = array(list(range(nrem,nadd))) 2078 self.elem[indices[ar]] = elem[ar] 2079 self.pos[indices[ar]] = pos[ar] 2080 ii = indices[ar[-1]] 2081 self.set_elem( list(self.elem[0:ii])+list(elem[er])+list(self.elem[ii:]) ) 2082 self.pos = array( list(self.pos[0:ii])+list(pos[er])+list(self.pos[ii:]) ) 2083 else: 2084 self.elem[indices] = elem[:] 2085 self.pos[indices] = pos[:] 2086 #end if 2087 self.remove_folded_structure() 2088 #end def replace 2089 2090 2091 # test needed 2092 def replace_nearest(self,elem,pos=None): 2093 if isinstance(elem,Structure): 2094 cell = elem 2095 elem = cell.elem 2096 pos = cell.pos 2097 #end if 2098 nn = nearest_neighbors(1,self.pos,pos) 2099 np = len(pos) 2100 nps= len(self.pos) 2101 d = empty((np,)) 2102 ip = array(list(range(np))) 2103 ips= nn.ravel() 2104 for i in ip: 2105 j = ips[i] 2106 d[i]=sqrt(((pos[i]-self.pos[j])**2).sum()) 2107 #end for 2108 order = d.argsort() 2109 ip = ip[order] 2110 ips=ips[order] 2111 replacable = empty((nps,)) 2112 replacable[:] = False 2113 replacable[ips]=True 2114 insert = [] 2115 last_replaced=nps-1 2116 for n in range(np): 2117 i = ip[n] 2118 j = ips[n] 2119 if replacable[j]: 2120 self.pos[j] = pos[i] 2121 self.elem[j]=elem[i] 2122 replacable[j]=False 2123 last_replaced = j 2124 else: 2125 insert.append(i) 2126 #end if 2127 #end for 2128 insert=array(insert) 2129 ii = last_replaced 2130 if len(insert)>0: 2131 self.set_elem( list(self.elem[0:ii])+list(elem[insert])+list(self.elem[ii:]) ) 2132 self.pos = array( list(self.pos[0:ii])+list(pos[insert])+list(self.pos[ii:]) ) 2133 #end if 2134 self.remove_folded_structure() 2135 #end def replace_nearest 2136 2137 2138 # test needed 2139 def point_defect(self,identifiers=None,elem=None,dr=None): 2140 if isinstance(elem,str): 2141 elem = [elem] 2142 if dr!=None: 2143 dr = [dr] 2144 #end if 2145 #end if 2146 if not 'point_defects' in self: 2147 self.point_defects = obj() 2148 #end if 2149 point_defects = self.point_defects 2150 ncenters = len(point_defects) 2151 if identifiers is None: 2152 index = ncenters 2153 if index>=len(self.pos): 2154 self.error('attempted to add a point defect at index {0}, which does not exist\n for reference there are {1} atoms in the structure'.format(index,len(self.pos))) 2155 #end if 2156 else: 2157 indices = self.locate(identifiers) 2158 if len(indices)>1: 2159 self.error('{0} atoms were located by identifiers provided\n a point defect replaces only a single atom\n atom indices located: {1}'.format(len(indices),indices)) 2160 #end if 2161 index = indices[0] 2162 #end if 2163 if elem is None: 2164 self.error('must supply substitutional elements comprising the point defect\n expected a list or similar for input argument elem') 2165 elif len(elem)>1 and dr is None: 2166 self.error('must supply displacements (dr) since many atoms comprise the point defect') 2167 elif dr!=None and len(elem)!=len(dr): 2168 self.error('elem and dr must have the same length') 2169 #end if 2170 r = self.pos[index] 2171 e = self.elem[index] 2172 elem = array(elem) 2173 pos = zeros((len(elem),len(r))) 2174 if dr is None: 2175 rc = r 2176 for i in range(len(elem)): 2177 pos[i] = r 2178 #end for 2179 else: 2180 nrc = 0 2181 rc = 0*r 2182 dr = array(dr) 2183 for i in range(len(elem)): 2184 pos[i] = r + dr[i] 2185 if norm(dr[i])>1e-5: 2186 rc+=dr[i] 2187 nrc+=1 2188 #end if 2189 #end for 2190 if nrc==0: 2191 rc = r 2192 else: 2193 rc = r + rc/nrc 2194 #end if 2195 #end if 2196 point_defect = obj( 2197 center = rc, 2198 elem_replaced = e, 2199 elem = elem, 2200 pos = pos 2201 ) 2202 point_defects.append(point_defect) 2203 elist = list(self.elem) 2204 plist = list(self.pos) 2205 if len(elem)==0 or len(elem)==1 and elem[0]=='': 2206 elist.pop(index) 2207 plist.pop(index) 2208 else: 2209 elist[index] = elem[0] 2210 plist[index] = pos[0] 2211 for i in range(1,len(elem)): 2212 elist.append(elem[i]) 2213 plist.append(pos[i]) 2214 #end for 2215 #end if 2216 self.set_elem(elist) 2217 self.pos = array(plist) 2218 self.remove_folded_structure() 2219 #end def point_defect 2220 2221 2222 # test needed 2223 def species(self,symbol=False): 2224 if not symbol: 2225 return set(self.elem) 2226 else: 2227 species_labels = set(self.elem) 2228 species = set() 2229 for e in species_labels: 2230 is_elem,symbol = is_element(e,symbol=True) 2231 species.add(symbol) 2232 #end for 2233 return species_labels,species 2234 #end if 2235 #end def species 2236 2237 2238 # test needed 2239 def ordered_species(self,symbol=False): 2240 speclab_set = set() 2241 species_labels = [] 2242 if not symbol: 2243 for e in self.elem: 2244 if e not in speclab_set: 2245 speclab_set.add(e) 2246 species_labels.append(e) 2247 #end if 2248 #end for 2249 return species_labels 2250 else: 2251 species = [] 2252 spec_set = set() 2253 for e in self.elem: 2254 is_elem,symbol = is_element(e,symbol=True) 2255 if e not in speclab_set: 2256 speclab_set.add(e) 2257 species_labels.append(e) 2258 #end if 2259 if symbol not in spec_set: 2260 spec_set.add(symbol) 2261 species.append(symbol) 2262 #end if 2263 #end for 2264 return species_labels,species 2265 #end if 2266 #end def ordered_species 2267 2268 2269 # test needed 2270 def order_by_species(self,folded=False): 2271 species = [] 2272 species_counts = [] 2273 elem_indices = [] 2274 2275 spec_set = set() 2276 for i in range(len(self.elem)): 2277 e = self.elem[i] 2278 if not e in spec_set: 2279 spec_set.add(e) 2280 species.append(e) 2281 species_counts.append(0) 2282 elem_indices.append([]) 2283 #end if 2284 sindex = species.index(e) 2285 species_counts[sindex] += 1 2286 elem_indices[sindex].append(i) 2287 #end for 2288 2289 elem_order = [] 2290 for elem_inds in elem_indices: 2291 elem_order.extend(elem_inds) 2292 #end for 2293 self.reorder(elem_order) 2294 2295 if folded and self.folded_structure!=None: 2296 self.folded_structure.order_by_species(folded) 2297 #end if 2298 2299 return species,species_counts 2300 #end def order_by_species 2301 2302 2303 # test needed 2304 def reorder(self,order): 2305 order = array(order) 2306 self.elem = self.elem[order] 2307 self.pos = self.pos[order] 2308 #end def reorder 2309 2310 2311 # test needed 2312 # find layers parallel to a particular cell face 2313 # layers are found by scanning a window of width dtol along the axis and counting 2314 # the number of atoms within the window. window position w/ max number of atoms 2315 # defines the layer. layer distance is the window position. 2316 # the resolution of the scan is determined by dbin 2317 # (axis length)/dbin is the number of fine bins 2318 # dtol/dbin is the number of fine bins in the moving (boxcar) window 2319 # plot=True: plot the layer histogram (fine hist and moving average) 2320 # composition=True: return the composition of each layer (count of each species) 2321 # returns an object containing indices of atoms in each layer by distance along axis 2322 # example: structure w/ 3 layers of 4 atoms each at distances 3.0, 6.0, and 9.0 Angs. 2323 # layers 2324 # 3.0 = [ 0, 1, 2, 3 ] 2325 # 6.0 = [ 4, 5, 6, 7 ] 2326 # 9.0 = [ 8, 9,10,11 ] 2327 def layers(self,axis=0,dtol=0.03,dbin=0.01,plot=False,composition=False): 2328 nbox = int(dtol/dbin) 2329 if nbox%2==0: 2330 nbox+=1 2331 #end if 2332 nwind = (nbox-1)//2 2333 s = self.copy() 2334 s.recenter() 2335 vaxis = s.axes[axis] 2336 daxis = norm(vaxis) 2337 naxis = vaxis/daxis 2338 dbin = dtol/nbox 2339 nbins = int(ceil(daxis/dbin)) 2340 dbin = daxis/nbins 2341 dbins = daxis*(arange(nbins)+.5)/nbins 2342 dists = daxis*s.pos_unit()[:,axis] 2343 hist = zeros((nbins,),dtype=int) 2344 boxhist = zeros((nbins,),dtype=int) 2345 ihist = obj() 2346 iboxhist = obj() 2347 index = 0 2348 for d in dists: 2349 ibin = int(floor(d/dbin)) 2350 hist[ibin]+=1 2351 if not ibin in ihist: 2352 ihist[ibin] = [] 2353 #end if 2354 ihist[ibin].append(index) 2355 index+=1 2356 #end for 2357 for ib in range(nbins): 2358 for i in range(ib-nwind,ib+nwind+1): 2359 n = hist[i%nbins] 2360 if n>0: 2361 boxhist[ib]+=n 2362 if not ib in iboxhist: 2363 iboxhist[ib] = [] 2364 #end if 2365 iboxhist[ib].extend(ihist[i%nbins]) 2366 #end if 2367 #end for 2368 #end for 2369 peaks = [] 2370 nlast=0 2371 for ib in range(nbins): 2372 n = boxhist[ib] 2373 if nlast==0 and n>0: 2374 pcur = [] 2375 peaks.append(pcur) 2376 #end if 2377 if n>0: 2378 pcur.append(ib) 2379 #end if 2380 nlast = n 2381 #end for 2382 if boxhist[0]>0 and boxhist[-1]>0: 2383 peaks[0].extend(peaks[-1]) 2384 peaks.pop() 2385 #end if 2386 layers = obj() 2387 ip = [] 2388 for peak in peaks: 2389 ib = peak[boxhist[peak].argmax()] 2390 ip.append(ib) 2391 pindices = iboxhist[ib] 2392 ldist = dbins[ib] # distance is along an axis vector 2393 faxis = self.face_vectors()[axis] 2394 ldist = dot(ldist*naxis,faxis/norm(faxis)) 2395 layers[ldist] = array(pindices,dtype=int) 2396 #end for 2397 if plot: 2398 plt.plot(dbins,boxhist,'b.-',label='boxcar histogram') 2399 plt.plot(dbins,hist,'r.-',label='fine histogram') 2400 plt.plot(dbins[ip],boxhist[ip],'rv',markersize=20) 2401 plt.show() 2402 plt.legend() 2403 #end if 2404 if not composition: 2405 return layers 2406 else: 2407 return layers,self.layer_composition(layers) 2408 #end if 2409 #end def layers 2410 2411 2412 # test needed 2413 def layer_composition(self,layers): 2414 lcomp = obj() 2415 for d,ind in layers.items(): 2416 comp = obj() 2417 elem = self.elem[ind] 2418 for e in elem: 2419 if e not in comp: 2420 comp[e] = 1 2421 else: 2422 comp[e] += 1 2423 #end if 2424 #end for 2425 lcomp[d]=comp 2426 #end for 2427 return lcomp 2428 #end def layer_composition 2429 2430 2431 # test needed 2432 def shells(self,identifiers,radii=None,exterior=False,cumshells=False,distances=False,dtol=1e-6): 2433 # get indices for 'core' and 'bulk' 2434 # core is selected by identifiers, forms core for shells to be built around 2435 # bulk is all atoms except for core 2436 if identifiers=='point_defects': 2437 if not 'point_defects' in self: 2438 self.error('requested shells around point defects, but structure has no point defects') 2439 #end if 2440 core = [] 2441 for pd in self.point_defects: 2442 core.append(pd.center) 2443 #end for 2444 core = array(core) 2445 bulk_ind = self.locate(core,radii=dtol,exterior=True) 2446 core_ind = self.locate(bulk_ind,exterior=True) 2447 bulk = self.pos[bulk_ind] 2448 else: 2449 core_ind = self.locate(identifiers,radii,exterior) 2450 bulk_ind = self.locate(core_ind,exterior=True) 2451 core = self.pos[core_ind] 2452 bulk = self.pos[bulk_ind] 2453 #end if 2454 bulk_ind = array(bulk_ind,dtype=int) 2455 # build distance table between bulk and core 2456 dtable = self.distance_table(bulk,core) 2457 # find shortest distance for each bulk atom to any core atom and order by distance 2458 dist = dtable.min(1) 2459 ind = arange(len(bulk)) 2460 order = dist.argsort() 2461 dist = dist[order] 2462 ind = bulk_ind[ind[order]] 2463 # find shells around the core 2464 # the closest atom to the core starts the first shell and defines a shell distance 2465 # other atoms are in the shell if within dtol distance of the first atom 2466 # otherwise a new shell is started 2467 ns = 0 2468 ds = -1 2469 shells = obj() 2470 shells[ns] = list(core_ind) # first shell is all core atoms 2471 dshells = [0.] 2472 for n in range(len(dist)): 2473 if abs(dist[n]-ds)>dtol: 2474 shell = [ind[n]] # new shell starts with single atom 2475 ns+=1 2476 shells[ns] = shell 2477 ds = dist[n] # shell distance is distance of this atom from core 2478 dshells.append(ds) 2479 else: 2480 shell.append(ind[n]) 2481 #end if 2482 #end for 2483 dshells = array(dshells,dtype=float) 2484 results = [shells] 2485 if cumshells: 2486 # assemble cumulative shells, ie cumshell[ns] = sum(shells[n],n=0 to ns) 2487 cumshells = obj() 2488 cumshells[0] = list(shells[0]) 2489 for ns in range(1,len(shells)): 2490 cumshells[ns] = cumshells[ns-1]+shells[ns] 2491 #end for 2492 for ns,cshell in cumshells.items(): 2493 cumshells[ns] = array(cshell,dtype=int) 2494 #end for 2495 results.append(cumshells) 2496 #end if 2497 for ns,shell in shells.items(): 2498 shells[ns] = array(shell,dtype=int) 2499 if distances: 2500 results.append(dshells) 2501 #end if 2502 if len(results)==1: 2503 results = results[0] 2504 #end if 2505 return results 2506 #end def shells 2507 2508 2509 # test needed 2510 # find connected sets of atoms. 2511 # indices is a list of atomic indices to consider (self.pos[indices] are their positions) 2512 # atoms are considered connected if they are within rmax of each other 2513 # order sets the maximum number of atoms in any connected graph 2514 # order = 1 returns single atoms 2515 # order = 2 returns dimers + order=1 results 2516 # order = 3 returns trimers + order=2 results 2517 # ... 2518 # degree is explained w/ an example: a triangle of atoms 0,1,2 and a line of atoms 3,4,5 (3 & 5 are not neighbors) 2519 # degree = False : returned object (cgraphs) has following structure: 2520 # cgraphs[1] = [ (0,), (1,), (2,), (3,), (4,), (5,) ] # first order connected graphs (atoms) 2521 # cgraphs[2] = [ (0,1), (0,2), (1,2), (3,4), (4,5) ] # second order connected graphs (dimers) 2522 # cgraphs[3] = [ (0,1,2), (3,4,5) ] # third order connected graphs (trimers) 2523 # degree = True : returned object (cgraphs) has following structure: 2524 # cgraphs 2525 # 1 # first order connected graphs (atoms) 2526 # 0 # sum of vertex degrees is 0 (a single atom has no neighbors) 2527 # (0,) = [ (0,), (1,), (2,), (3,), (4,), (5,) ] # graphs with vertex degree (0,) 2528 # 2 # second order connected graphs (dimers) 2529 # 2 # sum of vertex degrees is 2 (each atom is connected to 1 neighbor) 2530 # (1,1) = [ (0,1), (0,2), (1,2), (3,4), (4,5) ] # graphs with vertex degree (1,1) 2531 # 3 # third order connected graphs (trimers) 2532 # 4 # sum of vertex degrees is 4 (2 atoms have 1 neighbor and 1 atom has 2) 2533 # (1,1,2) = [ (3,5,4) ] 2534 # 6 # sum of vertex degrees is 6 (each atom is connected to 2 others) 2535 # (2,2,2) = [ (0,1,2) ] # graphs with vertex degree (2,2,2) 2536 def connected_graphs(self,order,indices=None,rmax=None,nmax=None,voronoi=False,degree=False,site_maps=False,**spec_max): 2537 if indices is None: 2538 indices = arange(len(self.pos),dtype=int) 2539 pos = self.pos 2540 else: 2541 pos = self.pos[indices] 2542 #end if 2543 np = len(indices) 2544 neigh_table = [] 2545 actual_indices = None 2546 if voronoi: 2547 actual_indices = True 2548 neighbors = self.voronoi_neighbors(indices,restrict=True,distance_ordered=False) 2549 for nilist in neighbors: 2550 neigh_table.append(nilist) 2551 #end for 2552 else: 2553 actual_indices = False 2554 elem = set(self.elem[indices]) 2555 spec = set(spec_max.keys()) 2556 if spec==elem or rmax!=None: 2557 None 2558 elif spec<elem and nmax!=None: 2559 for e in elem: 2560 if e not in spec: 2561 spec_max[e] = nmax 2562 #end if 2563 #end for 2564 #end if 2565 # get neighbor table for subset of atoms specified by indices 2566 nt,dt = self.neighbor_table(pos,pos,distances=True) 2567 # determine how many neighbors to consider based on rmax (all are neighbors if rmax is None) 2568 nneigh = zeros((np,),dtype=int) 2569 if len(spec_max)>0: 2570 for n in range(np): 2571 nneigh[n] = min(spec_max[self.elem[n]],len(nt[n])) 2572 #end for 2573 elif rmax is None: 2574 nneigh[:] = np 2575 else: 2576 nneigh = (dt<rmax).sum(1) 2577 #end if 2578 for i in range(np): 2579 neigh_table.append(nt[i,1:nneigh[i]]) 2580 #end for 2581 del nt,dt,nneigh,elem,spec,rmax 2582 #end if 2583 neigh_table = array(neigh_table,dtype=int) 2584 # record which atoms are neighbors to each other 2585 neigh_pairs = set() 2586 if actual_indices: 2587 for i in range(np): 2588 for ni in neigh_table[i]: 2589 neigh_pairs.add((i,ni)) 2590 neigh_pairs.add((ni,i)) 2591 #end for 2592 #end for 2593 else: 2594 for i in range(np): 2595 for ni in neigh_table[i]: 2596 ii = indices[i] 2597 jj = indices[ni] 2598 neigh_pairs.add((ii,jj)) 2599 neigh_pairs.add((jj,ii)) 2600 #end for 2601 #end for 2602 #end if 2603 # find the connected graphs 2604 graphs_found = set() # map to contain tuples of connected atom's indices 2605 cgraphs = obj() 2606 for o in range(1,order+1): # organize by order 2607 cgraphs[o] = [] 2608 #end for 2609 if order>0: 2610 cg = cgraphs[1] 2611 for i in range(np): # list of single atoms 2612 gi = (i,) # graph indices 2613 cg.append(gi) # add graph to graph list of order 1 2614 graphs_found.add(gi) # add graph to set of all graphs 2615 #end for 2616 for o in range(2,order+1): # graphs of order o are found by adding all 2617 cglast = cgraphs[o-1] # possible single neighbors to each graph of order o-1 2618 cg = cgraphs[o] 2619 for gilast in cglast: # all graphs of order o-1 2620 for i in gilast: # all indices in each graph of order o-1 2621 for ni in neigh_table[i]: # neighbors of selected atom in o-1 graph 2622 gi = tuple(sorted(gilast+(ni,))) # new graph with neighbor added 2623 if gi not in graphs_found and len(set(gi))==o: # add it if it is new and really is order o 2624 graphs_found.add(gi) # add graph to set of all graphs 2625 cg.append(gi) # add graph to graph list of order o 2626 #end if 2627 #end for 2628 #end for 2629 #end for 2630 #end for 2631 #end if 2632 if actual_indices: 2633 for o,cg in cgraphs.items(): 2634 cgraphs[o] = array(cg,dtype=int) 2635 #end for 2636 else: 2637 # map indices back to actual atomic indices 2638 for o,cg in cgraphs.items(): 2639 cgmap = [] 2640 for gi in cg: 2641 #gi = array(gi) 2642 gimap = tuple(sorted(indices[array(gi)])) 2643 cgmap.append(gimap) 2644 #end for 2645 cgraphs[o] = array(sorted(cgmap),dtype=int) 2646 #end for 2647 #end if 2648 # reorganize the graph listing by cluster and vertex degree, if desired 2649 if degree: 2650 #degree_map = obj() 2651 cgraphs_deg = obj() 2652 for o,cg in cgraphs.items(): 2653 dgo = obj() 2654 cgraphs_deg[o] = dgo 2655 for gi in cg: 2656 di = zeros((o,),dtype=int) 2657 for m in range(o): 2658 i = gi[m] 2659 for n in range(m+1,o): 2660 j = gi[n] 2661 if (i,j) in neigh_pairs: 2662 di[m]+=1 2663 di[n]+=1 2664 #end if 2665 #end for 2666 #end for 2667 d = int(di.sum()) 2668 dorder = di.argsort() 2669 di = tuple(di[dorder]) 2670 gi = tuple(array(gi)[dorder]) 2671 if not d in dgo: 2672 dgo[d]=obj() 2673 #end if 2674 dgd = dgo[d] 2675 if not di in dgd: 2676 dgd[di] = [] 2677 #end if 2678 dgd[di].append(gi) 2679 #degree_map[gi] = d,di 2680 #end for 2681 for dgd in dgo: 2682 for di,dgi in dgd.items(): 2683 dgd[di]=array(sorted(dgi),dtype=int) 2684 #end for 2685 #end for 2686 #end for 2687 cgraphs = cgraphs_deg 2688 #end if 2689 2690 if not site_maps: 2691 return cgraphs 2692 else: 2693 cmaps = obj() 2694 if not degree: 2695 for order,og in cgraphs.items(): 2696 cmap = obj() 2697 for slist in og: 2698 for s in slist: 2699 if not s in cmap: 2700 cmap[s] = obj() 2701 #end if 2702 cmap[s].append(slist) 2703 #end for 2704 #end for 2705 cmaps[order] = cmap 2706 #end for 2707 else: 2708 for order,og in cgraphs.items(): 2709 for total_degree,tg in og.items(): 2710 for local_degree,lg in tg.items(): 2711 cmap = obj() 2712 for slist in lg: 2713 n=0 2714 for s in slist: 2715 d = local_degree[n] 2716 if not s in cmap: 2717 cmap[s] = obj() 2718 #end if 2719 if not d in cmap[s]: 2720 cmap[s][d] = obj() 2721 #end if 2722 cmap[s][d].append(slist) 2723 n+=1 2724 #end for 2725 #end for 2726 cmaps.add_attribute_path((order,total_degree,local_degree),cmap) 2727 #end for 2728 #end for 2729 #end for 2730 #end if 2731 return cgraphs,cmaps 2732 #end if 2733 #end def connected_graphs 2734 2735 2736 # test needed 2737 # returns connected graphs that are rings up to the requested order 2738 # rings are constructed by pairing lines that share endpoints 2739 # all vertices of a ring have degree two 2740 def ring_graphs(self,order,**kwargs): 2741 # get all half order connected graphs 2742 line_order = order//2+order%2+1 2743 cgraphs = self.connected_graphs(line_order,degree=True,site_maps=False,**kwargs) 2744 # collect half order graphs that are lines 2745 lgraphs = obj() 2746 for o in range(2,line_order+1): 2747 total_degree = 2*o-2 2748 vertex_degree = tuple([1,1]+(o-2)*[2]) 2749 lg = None 2750 if o in cgraphs: 2751 cg = cgraphs[o] 2752 if total_degree in cg: 2753 dg = cg[total_degree] 2754 if vertex_degree in dg: 2755 lg = dg[vertex_degree] 2756 #end if 2757 #end if 2758 #end if 2759 if lg!=None: 2760 lg_end = obj() 2761 for gi in lg: 2762 end_key = tuple(sorted(gi[0:2])) # end points 2763 if end_key not in lg_end: 2764 lg_end[end_key] = [] 2765 #end if 2766 lg_end[end_key].append(tuple(gi)) 2767 #end for 2768 lgraphs[o] = lg_end 2769 #end if 2770 #end for 2771 # contruct rings from lines that share endpoints 2772 rgraphs = obj() 2773 for o in range(3,order+1): 2774 o1 = o/2+1 # split half order for odd, same for even, 2775 o2 = o1+o%2 2776 lg1 = lgraphs.get_optional(o1,None) # sets of half order lines 2777 lg2 = lgraphs.get_optional(o2,None) 2778 if lg1!=None and lg2!=None: 2779 rg = [] 2780 rset = set() 2781 for end_key,llist1 in lg1.items(): # list of lines sharing endpoints 2782 if end_key in lg2: 2783 llist2 = lg2[end_key] # second list of lines sharing endpoints 2784 for gi1 in llist1: # combine line pairs into rings 2785 for gi2 in llist2: 2786 ri = tuple(sorted(set(gi1+gi2[2:]))) # ring indices 2787 if ri not in rset and len(ri)==o: # exclude repeated lines or rings 2788 rg.append(ri) 2789 rset.add(ri) 2790 #end if 2791 #end for 2792 #end for 2793 #end if 2794 #end for 2795 rgraphs[o] = array(sorted(rg),dtype=int) 2796 #end if 2797 #end for 2798 return rgraphs 2799 #end def ring_graphs 2800 2801 2802 # test needed 2803 # find the centroid of a set of points/atoms in min image convention 2804 def min_image_centroid(self,points=None,indices=None): 2805 if indices!=None: 2806 points = self.pos[indices] 2807 elif points is None: 2808 self.error('points or images must be provided to min_image_centroid') 2809 #end if 2810 p = array(points,dtype=float) 2811 cprev = p[0]+1e99 2812 c = p[0] 2813 while(norm(c-cprev)>1e-8): 2814 p = self.cell_image(p,center=c) 2815 cprev = c 2816 c = p.mean(axis=0) 2817 #end def min_image_centroid 2818 return c 2819 #end def min_image_centroid 2820 2821 2822 # test needed 2823 # find min image centroids of multiple sets of points/atoms 2824 def min_image_centroids(self,points=None,indices=None): 2825 cents = [] 2826 if points!=None: 2827 for p in points: 2828 cents.append(self.min_image_centroid(p)) 2829 #end for 2830 elif indices!=None: 2831 for ind in indices: 2832 cents.append(self.min_image_centroid(indices=ind)) 2833 #end for 2834 else: 2835 self.error('points or images must be provided to min_image_centroid') 2836 #end if 2837 return array(cents,dtype=float) 2838 #end def min_image_centroids 2839 2840 2841 def min_image_vectors(self,points=None,points2=None,axes=None,pairs=True): 2842 if points is None: 2843 points = self.pos 2844 elif isinstance(points,Structure): 2845 points = points.pos 2846 #end if 2847 if axes is None: 2848 axes = self.axes 2849 #end if 2850 axinv = inv(axes) 2851 points = array(points) 2852 single = points.shape==(self.dim,) 2853 if single: 2854 points = [points] 2855 #end if 2856 if points2 is None: 2857 points2 = self.pos 2858 elif isinstance(points2,Structure): 2859 points2 = points2.pos 2860 elif points2.shape==(self.dim,): 2861 points2 = [points2] 2862 #end if 2863 npoints = len(points) 2864 npoints2 = len(points2) 2865 if pairs: 2866 vtable = empty((npoints,npoints2,self.dim),dtype=float) 2867 i=-1 2868 for p in points: 2869 i+=1 2870 j=-1 2871 for pp in points2: 2872 j+=1 2873 u = dot(pp-p,axinv) 2874 vtable[i,j] = dot(u-floor(u+.5),axes) 2875 #end for 2876 #end for 2877 result = vtable 2878 else: 2879 if npoints!=npoints2: 2880 self.error('cannot create one to one minimum image vectors, point sets differ in length\n npoints1 = {0}\n npoints2 = {1}'.format(npoints,npoints2)) 2881 #end if 2882 vectors = empty((npoints,self.dim),dtype=float) 2883 n = 0 2884 for p in points: 2885 pp = points2[n] 2886 u = dot(pp-p,axinv) 2887 vectors[n] = dot(u-floor(u+.5),axes) 2888 n+=1 2889 #end for 2890 result = vectors 2891 #end if 2892 2893 return result 2894 #end def min_image_vectors 2895 2896 2897 def min_image_distances(self,points=None,points2=None,axes=None,vectors=False,pairs=True): 2898 vtable = self.min_image_vectors(points,points2,axes,pairs=pairs) 2899 rdim = len(vtable.shape)-1 2900 dtable = sqrt((vtable**2).sum(rdim)) 2901 if not vectors: 2902 return dtable 2903 else: 2904 return dtable,vtable 2905 #end if 2906 #end def min_image_distances 2907 2908 2909 def distance_table(self,points=None,points2=None,axes=None,vectors=False): 2910 return self.min_image_distances(points,points2,axes,vectors) 2911 #end def distance_table 2912 2913 2914 def vector_table(self,points=None,points2=None,axes=None): 2915 return self.min_image_vectors(points,points2,axes) 2916 #end def vector_table 2917 2918 2919 def neighbor_table(self,points=None,points2=None,axes=None,distances=False,vectors=False): 2920 dtable,vtable = self.min_image_distances(points,points2,axes,vectors=True) 2921 ntable = empty(dtable.shape,dtype=int) 2922 for i in range(len(dtable)): 2923 ntable[i] = dtable[i].argsort() 2924 #end for 2925 results = [ntable] 2926 if distances: 2927 for i in range(len(dtable)): 2928 dtable[i] = dtable[i][ntable[i]] 2929 #end for 2930 results.append(dtable) 2931 #end if 2932 if vectors: 2933 for i in range(len(vtable)): 2934 vtable[i] = vtable[i][ntable[i]] 2935 #end for 2936 results.append(vtable) 2937 #end if 2938 if len(results)==1: 2939 results = results[0] 2940 #end if 2941 return results 2942 #end def neighbor_table 2943 2944 2945 # test needed 2946 def min_image_norms(self,points,norms): 2947 if isinstance(norms,int) or isinstance(norms,float): 2948 norms = [norms] 2949 #end if 2950 vtable = self.min_image_vectors(points) 2951 rdim = len(vtable.shape)-1 2952 nout = [] 2953 for p in norms: 2954 nout.append( ((abs(vtable)**p).sum(rdim))**(1./p) ) 2955 #end for 2956 if len(norms)==1: 2957 nout = nout[0] 2958 #end if 2959 return nout 2960 #end def min_image_norms 2961 2962 2963 # test needed 2964 # get all neighbors according to contacting voronoi polyhedra in PBC 2965 def voronoi_neighbors(self,indices=None,restrict=False,distance_ordered=True): 2966 if indices is None: 2967 indices = arange(len(self.pos)) 2968 #end if 2969 indices = set(indices) 2970 # make a new version of this (small cell) 2971 sn = self.copy() 2972 sn.recenter() 2973 # tile a large cell periodically 2974 d = 3 2975 t = tuple(zeros((d,),dtype=int)+3) 2976 ss = sn.tile(t) 2977 ss.recenter(sn.center) 2978 # get nearest neighbor index pairs in the large cell 2979 neigh_pairs = voronoi_neighbors(ss.pos) 2980 # create a mapping from large to small indices 2981 large_to_small = 3**d*list(range(len(self.pos))) 2982 # find the neighbor pairs in the small cell 2983 neighbors = obj() 2984 small_inds = set(ss.locate(sn.pos)) 2985 for n in range(len(neigh_pairs)): 2986 i,j = neigh_pairs[n,:] 2987 if i in small_inds or j in small_inds: # pairs w/ at least one in cell image 2988 i = large_to_small[i] # mapping to small cell indices 2989 j = large_to_small[j] 2990 if not restrict or (i in indices and j in indices): # restrict to orig index set 2991 if not i in neighbors: 2992 neighbors[i] = [j] 2993 else: 2994 neighbors[i].append(j) 2995 #ned if 2996 if not j in neighbors: 2997 neighbors[j] = [i] 2998 else: 2999 neighbors[j].append(i) 3000 #end if 3001 #end if 3002 #end if 3003 #end for 3004 # remove any duplicates and order by distance 3005 if distance_ordered: 3006 dt = self.distance_table() 3007 for i,ni in neighbors.items(): 3008 ni = array(list(set(ni)),dtype=int) 3009 di = dt[i,ni] 3010 order = di.argsort() 3011 neighbors[i] = ni[order] 3012 #end for 3013 else: # just remove duplicates 3014 for i,ni in neighbors.items(): 3015 neighbors[i] = array(list(set(ni)),dtype=int) 3016 #end for 3017 #end if 3018 return neighbors 3019 #end def voronoi_neighbors 3020 3021 3022 def voronoi_vectors(self,indices=None,restrict=None): 3023 ni = self.voronoi_neighbors(indices,restrict) 3024 vt = self.vector_table() 3025 vv = obj() 3026 for i,vi in ni.items(): 3027 vv[i] = vt[i,vi] 3028 #end for 3029 return vv 3030 #end def voronoi_vectors 3031 3032 3033 def voronoi_distances(self,indices=None,restrict=False): 3034 vv = self.voronoi_vectors(indices,restrict) 3035 vd = obj() 3036 for i,vvi in vv.items(): 3037 vd[i] = np.linalg.norm(vvi,axis=1) 3038 #end for 3039 return vd 3040 #end def voronoi_distances 3041 3042 3043 def voronoi_radii(self,indices=None,restrict=None): 3044 vd = self.voronoi_distances(indices,restrict) 3045 vr = obj() 3046 for i,vdi in vd.items(): 3047 vr[i] = vdi.min()/2 3048 #end for 3049 return vr 3050 #end def voronoi_radii 3051 3052 3053 def voronoi_species_radii(self): 3054 vr = self.voronoi_radii() 3055 vsr = obj() 3056 for i,r in vr.items(): 3057 e = self.elem[i] 3058 if e not in vsr: 3059 vsr[e] = r 3060 else: 3061 vsr[e] = min(vsr[e],r) 3062 #end if 3063 #end for 3064 return vsr 3065 #end def voronoi_species_radii 3066 3067 3068 # test needed 3069 # get nearest neighbors according to constrants (voronoi, max distance, coord. number) 3070 def nearest_neighbors(self,indices=None,rmax=None,nmax=None,restrict=False,voronoi=False,distances=False,**spec_max): 3071 if indices is None: 3072 indices = arange(len(self.pos)) 3073 #end if 3074 elem = set(self.elem[indices]) 3075 spec = set(spec_max.keys()) 3076 if spec==elem or rmax!=None or voronoi: 3077 None 3078 elif spec<elem and nmax!=None: 3079 for e in elem: 3080 if e not in spec: 3081 spec_max[e] = nmax 3082 #end if 3083 #end for 3084 else: 3085 self.error('must specify nmax for all species\n species present: {0}\n you only provided nmax for these species: {1}'.format(sorted(elem),sorted(spec))) 3086 #end if 3087 pos = self.pos[indices] 3088 if not restrict: 3089 pos2 = self.pos 3090 else: 3091 pos2 = pos 3092 #end if 3093 if voronoi: 3094 neighbors = self.voronoi_neighbors(indices=indices,restrict=restrict) 3095 dt = self.distance_table(pos,pos2)[:,1:] 3096 else: 3097 nt,dt = self.neighbor_table(pos,pos2,distances=True) 3098 dt=dt[:,1:] 3099 nt=nt[:,1:] 3100 neighbors = list(nt) 3101 #end if 3102 for i in range(len(indices)): 3103 neighbors[i] = indices[neighbors[i]] 3104 #end for 3105 dist = list(dt) 3106 if rmax is None: 3107 for i in range(len(indices)): 3108 nn = neighbors[i] 3109 dn = dist[i] 3110 e = self.elem[indices[i]] 3111 if e in spec_max: 3112 smax = spec_max[e] 3113 if len(nn)>smax: 3114 neighbors[i] = nn[:smax] 3115 dist[i] = dn[:smax] 3116 #end if 3117 #end if 3118 #end for 3119 else: 3120 for i in range(len(indices)): 3121 neighbors[i] = neighbors[i][dt[i]<rmax] 3122 #end for 3123 #end if 3124 if not distances: 3125 return neighbors 3126 else: 3127 return neighbors,dist 3128 #end if 3129 #end def nearest_neighbors 3130 3131 3132 # test needed 3133 # determine local chemical coordination limited by constraints 3134 def chemical_coordination(self,indices=None,nmax=None,rmax=None,restrict=False,voronoi=False,neighbors=False,distances=False,**spec_max): 3135 if indices is None: 3136 indices = arange(len(self.pos)) 3137 #end if 3138 if not distances: 3139 neigh = self.nearest_neighbors(indices=indices,nmax=nmax,rmax=rmax,restrict=restrict,voronoi=voronoi,**spec_max) 3140 else: 3141 neigh,dist = self.nearest_neighbors(indices=indices,nmax=nmax,rmax=rmax,restrict=restrict,voronoi=voronoi,distances=True,**spec_max) 3142 #end if 3143 neigh_elem = [] 3144 for i in range(len(indices)): 3145 neigh_elem.extend(self.elem[neigh[i]]) 3146 #end for 3147 chem_key = tuple(sorted(set(neigh_elem))) 3148 chem_coord = zeros((len(indices),len(chem_key)),dtype=int) 3149 for i in range(len(indices)): 3150 counts = zeros((len(chem_key),),dtype=int) 3151 nn = list(self.elem[neigh[i]]) 3152 for n in range(len(counts)): 3153 chem_coord[i,n] = nn.count(chem_key[n]) 3154 #end for 3155 #end for 3156 chem_map = obj() 3157 i=0 3158 for coord in chem_coord: 3159 coord = tuple(coord) 3160 if not coord in chem_map: 3161 chem_map[coord] = [indices[i]] 3162 else: 3163 chem_map[coord].append(indices[i]) 3164 #end if 3165 i+=1 3166 #end for 3167 for coord,ind in chem_map.items(): 3168 chem_map[coord] = array(ind,dtype=int) 3169 #end for 3170 results = [chem_key,chem_coord,chem_map] 3171 if neighbors: 3172 results.append(neigh) 3173 #end if 3174 if distances: 3175 results.append(dist) 3176 #end if 3177 return results 3178 #end def chemical_coordination 3179 3180 3181 # test needed 3182 def rcore_max(self,units=None): 3183 nt,dt = self.neighbor_table(self.pos,distances=True) 3184 d = dt[:,1] 3185 rcm = d.min()/2 3186 if units!=None: 3187 rcm = convert(rcm,self.units,units) 3188 #end if 3189 return rcm 3190 #end def rcore_max 3191 3192 3193 # test needed 3194 def cell_image(self,p,center=None): 3195 pos = array(p,dtype=float) 3196 if center is None: 3197 c = self.center.copy() 3198 else: 3199 c = array(center,dtype=float) 3200 #end if 3201 axes = self.axes 3202 axinv = inv(axes) 3203 for i in range(len(pos)): 3204 u = dot(pos[i]-c,axinv) 3205 pos[i] = dot(u-floor(u+.5),axes)+c 3206 #end for 3207 return pos 3208 #end def cell_image 3209 3210 3211 # test needed 3212 def center_distances(self,points,center=None): 3213 if center is None: 3214 c = self.center.copy() 3215 else: 3216 c = array(center,dtype=float) 3217 #end if 3218 points = self.cell_image(points,center=c) 3219 for i in range(len(points)): 3220 points[i] -= c 3221 #end for 3222 return sqrt((points**2).sum(1)) 3223 #end def center_distances 3224 3225 3226 # test needed 3227 def recenter(self,center=None): 3228 if center is not None: 3229 self.center=array(center,dtype=float) 3230 #end if 3231 pos = self.pos 3232 c = empty((1,self.dim),dtype=float) 3233 c[:] = self.center[:] 3234 axes = self.axes 3235 axinv = inv(axes) 3236 for i in range(len(pos)): 3237 u = dot(pos[i]-c,axinv) 3238 pos[i] = dot(u-floor(u+.5),axes)+c 3239 #end for 3240 self.recenter_k() 3241 #end def recenter 3242 3243 3244 # test needed 3245 def recorner(self): 3246 pos = self.pos 3247 axes = self.axes 3248 axinv = inv(axes) 3249 for i in range(len(pos)): 3250 u = dot(pos[i],axinv) 3251 pos[i] = dot(u-floor(u),axes) 3252 #end for 3253 #end def recorner 3254 3255 3256 # test needed 3257 def recenter_k(self,kpoints=None,kaxes=None,kcenter=None,remove_duplicates=False): 3258 use_self = kpoints is None 3259 if use_self: 3260 kpoints=self.kpoints 3261 #end if 3262 if kaxes is None: 3263 kaxes=self.kaxes 3264 #end if 3265 if len(kpoints)>0: 3266 axes = kaxes 3267 axinv = inv(axes) 3268 if kcenter is None: 3269 c = axes.sum(0)/2 3270 else: 3271 c = array(kcenter) 3272 #end if 3273 for i in range(len(kpoints)): 3274 u = dot(kpoints[i]-c,axinv) 3275 u -= floor(u+.5) 3276 u[abs(u-.5)<1e-12] -= 1.0 3277 u[abs(u )<1e-12] = 0.0 3278 kpoints[i] = dot(u,axes)+c 3279 #end for 3280 if remove_duplicates: 3281 inside = self.inside(kpoints,axes,c) 3282 kpoints = kpoints[inside] 3283 nkpoints = len(kpoints) 3284 unique = empty((nkpoints,),dtype=bool) 3285 unique[:] = True 3286 nn = nearest_neighbors(1,kpoints) 3287 if nkpoints>1: 3288 nn.shape = nkpoints, 3289 dist = self.distances(kpoints,kpoints[nn]) 3290 tol = 1e-8 3291 duplicates = arange(nkpoints)[dist<tol] 3292 for i in duplicates: 3293 if unique[i]: 3294 for j in duplicates: 3295 if sqrt(((kpoints[i]-kpoints[j])**2).sum(1))<tol: 3296 unique[j] = False 3297 #end if 3298 #end for 3299 #end if 3300 #end for 3301 #end if 3302 kpoints = kpoints[unique] 3303 #end if 3304 #end if 3305 if use_self: 3306 self.kpoints = kpoints 3307 else: 3308 return kpoints 3309 #end if 3310 #end def recenter_k 3311 3312 3313 # test needed 3314 def inside(self,pos,axes=None,center=None,tol=1e-8,separate=False): 3315 if axes==None: 3316 axes=self.axes 3317 #end if 3318 if center==None: 3319 center=self.center 3320 #end if 3321 axes = array(axes) 3322 center = array(center) 3323 inside = [] 3324 surface = [] 3325 su = [] 3326 axinv = inv(axes) 3327 for i in range(len(pos)): 3328 u = dot(pos[i]-center,axinv) 3329 umax = abs(u).max() 3330 if abs(umax-.5)<tol: 3331 surface.append(i) 3332 su.append(u) 3333 elif umax<.5: 3334 inside.append(i) 3335 #end if 3336 #end for 3337 npos,dim = pos.shape 3338 drange = list(range(dim)) 3339 n = len(surface) 3340 i=0 3341 while i<n: 3342 j=i+1 3343 while j<n: 3344 du = abs(su[i]-su[j]) 3345 match = False 3346 for d in drange: 3347 match = match or abs(du[d]-1.)<tol 3348 #end for 3349 if match: 3350 surface[j]=surface[-1] 3351 surface.pop() 3352 su[j]=su[-1] 3353 su.pop() 3354 n-=1 3355 else: 3356 j+=1 3357 #end if 3358 #end while 3359 i+=1 3360 #end while 3361 if not separate: 3362 inside+=surface 3363 return inside 3364 else: 3365 return inside,surface 3366 #end if 3367 #end def inside 3368 3369 3370 def tile(self,*td,**kwargs): 3371 in_place = kwargs.pop('in_place',False) 3372 check = kwargs.pop('check',False) 3373 3374 dim = self.dim 3375 if len(td)==1: 3376 if isinstance(td[0],int): 3377 tiling = dim*[td[0]] 3378 else: 3379 tiling = td[0] 3380 #end if 3381 else: 3382 tiling = td 3383 #end if 3384 tiling = array(tiling) 3385 3386 matrix_tiling = tiling.shape == (dim,dim) 3387 3388 tilematrix,tilevector = reduce_tilematrix(tiling) 3389 3390 ncells = int(round( abs(det(tilematrix)) )) 3391 3392 if ncells==1 and abs(tilematrix-identity(self.dim)).sum()<1e-1: 3393 if in_place: 3394 return self 3395 else: 3396 return self.copy() 3397 #end if 3398 #end if 3399 3400 self.recenter() 3401 3402 elem = array(ncells*list(self.elem)) 3403 pos = self.tile_points(self.pos,self.axes,tilematrix,tilevector) 3404 axes = dot(tilematrix,self.axes) 3405 3406 center = axes.sum(0)/2 3407 kaxes = dot(inv(tilematrix.T),self.kaxes) 3408 kpoints = array(self.kpoints) 3409 kweights = array(self.kweights) 3410 mag = None 3411 frozen = None 3412 if self.mag is not None: 3413 mag = ncells*list(self.mag) 3414 #end if 3415 if self.frozen is not None: 3416 frozen = ncells*list(self.frozen) 3417 #end if 3418 3419 ts = self.copy() 3420 ts.center = center 3421 ts.set_elem(elem) 3422 ts.axes = axes 3423 ts.pos = pos 3424 ts.mag = mag 3425 ts.kaxes = kaxes 3426 ts.kpoints = kpoints 3427 ts.kweights = kweights 3428 ts.set_mag(mag) 3429 ts.set_frozen(frozen) 3430 ts.background_charge = ncells*self.background_charge 3431 3432 ts.recenter() 3433 ts.unique_kpoints() 3434 if self.is_tiled(): 3435 ts.tmatrix = dot(tilematrix,self.tmatrix) 3436 ts.folded_structure = self.folded_structure.copy() 3437 else: 3438 ts.tmatrix = tilematrix 3439 ts.folded_structure = self.copy() 3440 #end if 3441 3442 if in_place: 3443 self.clear() 3444 self.transfer_from(ts) 3445 ts = self 3446 #end if 3447 3448 if check: 3449 ts.check_tiling() 3450 #end if 3451 3452 return ts 3453 #end def tile 3454 3455 3456 def tile_points(self,points,axes,tilemat,tilevec=None): 3457 if tilevec is None: 3458 tilemat,tilevec = reduce_tilematrix(tilemat) 3459 #end if 3460 if not isinstance(tilemat,ndarray): 3461 tilemat = array(tilemat) 3462 #end if 3463 matrix_tiling = abs(tilemat-diag(diag(tilemat))).sum()>0.1 3464 if not matrix_tiling: 3465 return self.tile_points_simple(points,axes,diag(abs(tilemat))) 3466 else: 3467 if not isinstance(axes,ndarray): 3468 axes = array(axes) 3469 #end if 3470 if not isinstance(tilevec,ndarray): 3471 tilevec = array(tilevec) 3472 #end if 3473 dim = len(axes) 3474 npoints = len(points) 3475 ntpoints = npoints*int(round(abs(det(tilemat)))) 3476 if tilevec.size==dim: 3477 tilevec.shape = 1,dim 3478 #end if 3479 taxes = dot(tilemat,axes) 3480 success = False 3481 for tvec in tilevec: 3482 tpoints = self.tile_points_simple(points,axes,tvec) 3483 tpoints,weights,pmap = self.unique_points_fast(tpoints,taxes) 3484 if len(tpoints)==ntpoints: 3485 success = True 3486 break 3487 #end if 3488 #end for 3489 if not success: 3490 tpoints = self.tile_points_brute(points,axes,tilemat) 3491 tpoints,weights,pmap = self.unique_points_fast(tpoints,taxes) 3492 if len(tpoints)!=ntpoints: 3493 self.error('brute force tiling failed') 3494 #end if 3495 #end if 3496 #end if 3497 return tpoints 3498 #end def tile_points 3499 3500 3501 def tile_points_simple(self,points,axes,tilevec): 3502 if not isinstance(points,ndarray): 3503 points = array(points) 3504 #end if 3505 if not isinstance(tilevec,ndarray): 3506 tilevec = array(tilevec) 3507 #end if 3508 if not isinstance(axes,ndarray): 3509 axes = array(axes) 3510 #end if 3511 if len(points.shape)==1: 3512 npoints,dim = len(points),1 3513 else: 3514 npoints,dim = points.shape 3515 #end if 3516 t = tilevec 3517 ti = array(around(t),dtype=int) 3518 noninteger = abs(t-ti).sum()>1e-6 3519 if noninteger: 3520 tp = t.prod() 3521 if abs(tp-int(tp))>1e-6: 3522 self.error('tiling vector does not correspond to an integer volume change\ntiling vector: {0}\nvolume change: {1} {2} {3}'.format(tilevec,tilevec.prod(),ntpoints,int(ntpoints))) 3523 #end if 3524 t = array(ceil(t),dtype=int)+1 3525 else: 3526 t = ti 3527 #end if 3528 if t.min()<0: 3529 self.error('tiling vector cannot be negative\ntiling vector provided: {}'.format(t)) 3530 #end if 3531 ntpoints = npoints*int(round( t.prod() )) 3532 if ntpoints==0: 3533 tpoints = array([]) 3534 else: 3535 tpoints = empty((ntpoints,dim)) 3536 ns=0 3537 ne=npoints 3538 for k in range(t[2]): 3539 for j in range(t[1]): 3540 for i in range(t[0]): 3541 v = dot(array([[i,j,k]]),axes) 3542 for d in range(dim): 3543 tpoints[ns:ne,d] = points[:,d]+v[0,d] 3544 #end for 3545 ns+=npoints 3546 ne+=npoints 3547 #end for 3548 #end for 3549 #end for 3550 #end if 3551 return tpoints 3552 #end def tile_points_simple 3553 3554 3555 def tile_points_brute(self,points,axes,tilemat): 3556 tcorners = [[0,0,0], 3557 [1,0,0], 3558 [0,1,0], 3559 [0,0,1], 3560 [0,1,1], 3561 [1,0,1], 3562 [1,1,0], 3563 [1,1,1]] 3564 tcorners = dot(tcorners,tilemat) 3565 tmin = tcorners.min(axis=0) 3566 tmax = tcorners.max(axis=0) 3567 tilevec = tmax-tmin 3568 tpoints = self.tile_points_simple(points,axes,tilevec) 3569 return tpoints 3570 #end def tile_points_brute 3571 3572 3573 3574 def opt_tilematrix(self,*args,**kwargs): 3575 return optimal_tilematrix(self.axes,*args,**kwargs) 3576 #end def opt_tilematrix 3577 3578 3579 def tile_opt(self,*args,**kwargs): 3580 Topt,ropt = self.opt_tilematrix(*args,**kwargs) 3581 return self.tile(Topt) 3582 #end def tile_opt 3583 3584 3585 def check_tiling(self,tol=1e-6,exit=True): 3586 msg = '' 3587 if not self.is_tiled(): 3588 return msg 3589 #end if 3590 msgs = [] 3591 st = self 3592 s = self.folded_structure 3593 nt = len(st.pos) 3594 n = len(s.pos) 3595 if nt%n!=0: 3596 msgs.append('tiled atom count does is not divisible by untiled atom count') 3597 #end if 3598 vratio = st.volume()/s.volume() 3599 if abs(vratio-float(nt)/n)>tol: 3600 msgs.append('tiled/untiled volume ratio does not match tiled/untiled atom count ratio') 3601 #end if 3602 if abs(vratio-abs(det(st.tmatrix)))>tol: 3603 msgs.append('tiled/untiled volume ratio does not match tiling matrix determinant') 3604 #end if 3605 p,w,pmap = self.unique_points_fast(st.pos,st.axes) 3606 if len(p)!=nt: 3607 msgs.append('tiled positions are not unique') 3608 #end if 3609 if len(msgs)>0: 3610 msg = 'tiling check failed' 3611 for m in msgs: 3612 msg += '\n'+m 3613 #end for 3614 if exit: 3615 self.error(msg) 3616 #end if 3617 #end if 3618 return msg 3619 #end def check_tiling 3620 3621 3622 # test needed 3623 def kfold(self,tiling,kpoints,kweights): 3624 if isinstance(tiling,int): 3625 tiling = self.dim*[tiling] 3626 #end if 3627 tiling = array(tiling) 3628 if tiling.shape==(self.dim,self.dim): 3629 tiling = tiling.T 3630 #end if 3631 tilematrix,tilevector = reduce_tilematrix(tiling) 3632 ncells = int(round( abs(det(tilematrix)) )) 3633 kp = self.tile_points(kpoints,self.kaxes,tilematrix,tilevector) 3634 kw = array(ncells*list(kweights),dtype=float)/ncells 3635 return kp,kw 3636 #end def kfold 3637 3638 3639 def get_smallest(self): 3640 if self.has_folded(): 3641 return self.folded_structure 3642 else: 3643 return self 3644 #end if 3645 #end def get_smallest 3646 3647 3648 # test needed 3649 def fold(self,small,*requests): 3650 self.error('fold needs a developers attention to make it equivalent with tile') 3651 if self.dim!=3: 3652 self.error('fold is currently only implemented for 3 dimensions') 3653 #end if 3654 self.recenter_k() 3655 corners = [] 3656 ndim = len(small.axes) 3657 imin = empty((ndim,),dtype=int) 3658 imax = empty((ndim,),dtype=int) 3659 imin[:] = 1000000 3660 imax[:] = -1000000 3661 axinv = inv(self.kaxes) 3662 center = self.kaxes.sum(0)/2 3663 c = empty((1,3)) 3664 for k in -1,2: 3665 for j in -1,2: 3666 for i in -1,2: 3667 c[:] = i,j,k 3668 c = dot(c,small.kaxes) 3669 u = dot(c-center,axinv) 3670 for d in range(ndim): 3671 imin[d] = min(int(floor(u[0,d])),imin[d]) 3672 imax[d] = max(int(ceil(u[0,d])),imax[d]) 3673 #end for 3674 #end for 3675 #end for 3676 #end for 3677 3678 axes = small.kaxes 3679 axinv = inv(small.kaxes) 3680 3681 center = small.kaxes.sum(0)/2 3682 nkpoints = len(self.kpoints) 3683 kindices = [] 3684 kpoints = [] 3685 shift = empty((ndim,)) 3686 kr = list(range(nkpoints)) 3687 for k in range(imin[2],imax[2]+1): 3688 for j in range(imin[1],imax[1]+1): 3689 for i in range(imin[0],imax[0]+1): 3690 for n in kr: 3691 shift[:] = i,j,k 3692 shift = dot(shift,self.kaxes) 3693 kp = self.kpoints[n]+shift 3694 u = dot(kp-center,axinv) 3695 if abs(u).max()<.5+1e-10: 3696 kindices.append(n) 3697 kpoints.append(kp) 3698 #end if 3699 #end for 3700 #end for 3701 #end for 3702 #end for 3703 kindices = array(kindices) 3704 kpoints = array(kpoints) 3705 inside = self.inside(kpoints,axes,center) 3706 kindices = kindices[inside] 3707 kpoints = kpoints[inside] 3708 3709 small.kpoints = kpoints 3710 small.recenter_k() 3711 kpoints = array(small.kpoints) 3712 if len(requests)>0: 3713 results = [] 3714 for request in requests: 3715 if request=='kmap': 3716 kmap = obj() 3717 for k in self.kpoints: 3718 kmap[tuple(k)] = [] 3719 #end for 3720 for i in range(len(kpoints)): 3721 kp = tuple(self.kpoints[kindices[i]]) 3722 kmap[kp].append(array(kpoints[i])) 3723 #end for 3724 for kl,ks in kmap.items(): 3725 kmap[kl] = array(ks) 3726 #end for 3727 res = kmap 3728 elif request=='tilematrix': 3729 res = self.tilematrix(small) 3730 else: 3731 self.error(request+' is not a recognized input to fold') 3732 #end if 3733 results.append(res) 3734 #end if 3735 return results 3736 #end if 3737 #end def fold 3738 3739 3740 def tilematrix(self,small=None,tol=1e-6,status=False): 3741 if small is None: 3742 if self.folded_structure is not None: 3743 small = self.folded_structure 3744 else: 3745 return identity(self.dim,dtype=int) 3746 #end if 3747 #end if 3748 tm = dot(self.axes,inv(small.axes)) 3749 tilemat = array(around(tm),dtype=int) 3750 error = abs(tilemat-tm).sum() 3751 non_integer_elements = error > tol 3752 if status: 3753 return tilemat,not non_integer_elements 3754 else: 3755 if non_integer_elements: 3756 self.error('large cell cannot be constructed as an integer tiling of the small cell\nlarge cell axes:\n'+str(self.axes)+'\nsmall cell axes: \n'+str(small.axes)+'\nlarge/small:\n'+str(self.axes/small.axes)+'\ntiling matrix:\n'+str(tm)+'\nintegerized tiling matrix:\n'+str(tilemat)+'\nerror: '+str(error)+'\ntolerance: '+str(tol)) 3757 #end if 3758 return tilemat 3759 #end if 3760 #end def tilematrix 3761 3762 3763 def primitive(self,source=None,tmatrix=False,add_kpath=False,**kwargs): 3764 res = None 3765 allowed_sources = set(['seekpath']) 3766 if source is None or isinstance(source,bool): 3767 source = 'seekpath' 3768 #end if 3769 if source not in allowed_sources: 3770 self.error('source used to obtain primitive cell is unrecognized\nsource requested: {0}\nallowed sources: {1}'.format(source,sorted(allowed_sources))) 3771 #end if 3772 if source=='seekpath': 3773 res_skp = get_seekpath_full(structure=self,primitive=True,**kwargs) 3774 prim = res_skp.primitive 3775 T = res_skp.prim_tmatrix 3776 if add_kpath: 3777 prim.add_kpoints(res_skp.explicit_kpoints_abs) 3778 #end if 3779 if tmatrix: 3780 res = prim,T 3781 else: 3782 res = prim 3783 #end if 3784 else: 3785 self.error('primitive source "{0}" is not implemented\nplease contact a developer'.format(source)) 3786 #end if 3787 if prim.units!=self.units: 3788 prim.change_units(self.units) 3789 #end if 3790 return res 3791 #end def primitive 3792 3793 3794 def become_primitive(self,source=None,add_kpath=False,**kwargs): 3795 prim = self.primitive(source=source,add_kpath=add_kpath,**kwargs) 3796 self.clone_from(prim) 3797 #end def become_primitive 3798 3799 3800 def add_kpoints(self,kpoints,kweights=None,unique=False,recenter=True,cell_unit=False): 3801 if kweights is None: 3802 kweights = ones((len(kpoints),)) 3803 #end if 3804 if cell_unit: 3805 kpoints = np.dot(array(kpoints),self.kaxes) 3806 #end if 3807 self.kpoints = append(self.kpoints,kpoints,axis=0) 3808 self.kweights = append(self.kweights,kweights) 3809 if unique: 3810 self.unique_kpoints() 3811 #end if 3812 if recenter: 3813 self.recenter_k() #added because qmcpack cannot handle kpoints outside the box 3814 #end if 3815 if self.is_tiled(): 3816 kp,kw = self.kfold(self.tmatrix,kpoints,kweights) 3817 self.folded_structure.add_kpoints(kp,kw,unique=unique) 3818 #end if 3819 #end def add_kpoints 3820 3821 3822 # test needed 3823 def clear_kpoints(self): 3824 self.kpoints = empty((0,self.dim)) 3825 self.kweights = empty((0,)) 3826 if self.folded_structure!=None: 3827 self.folded_structure.clear_kpoints() 3828 #end if 3829 #end def clear_kpoints 3830 3831 3832 def kgrid_from_kspacing(self,kspacing): 3833 kgrid = [] 3834 for ka in self.kaxes: 3835 km = np.linalg.norm(ka) 3836 kg = int(np.ceil(km/kspacing)) 3837 kgrid.append(kg) 3838 #end for 3839 return tuple(kgrid) 3840 #end def kgrid_from_kspacing 3841 3842 3843 def add_kmesh(self,kgrid=None,kshift=None,unique=False,kspacing=None): 3844 if kspacing is not None: 3845 kgrid = self.kgrid_from_kspacing(kspacing) 3846 elif kgrid is None: 3847 self.error('kgrid input is required by add_kmesh') 3848 #end if 3849 self.add_kpoints(kmesh(self.kaxes,kgrid,kshift),unique=unique) 3850 #end def add_kmesh 3851 3852 3853 def add_symmetrized_kmesh(self,kgrid=None,kshift=(0,0,0),kspacing=None): 3854 # find kgrid from kspacing, if requested 3855 if kspacing is not None: 3856 kgrid = self.kgrid_from_kspacing(kspacing) 3857 elif kgrid is None: 3858 self.error('kgrid input is required by add_kmesh') 3859 #end if 3860 3861 # get spglib cell data structure 3862 cell = self.spglib_cell() 3863 3864 # get the symmetry mapping 3865 kmap,kpoints_int = spglib.get_ir_reciprocal_mesh( 3866 kgrid, 3867 cell, 3868 is_shift=kshift 3869 ) 3870 3871 # create the Monkhorst-Pack mesh 3872 kshift = array(kshift,dtype=float) 3873 okgrid = 1.0/array(kgrid,dtype=float) 3874 kpoints = empty(kpoints_int.shape,dtype=float) 3875 for i,ki in enumerate(kpoints_int): 3876 kpoints[i] = (ki+kshift)*okgrid 3877 #end for 3878 kpoints = dot(kpoints,self.kaxes) 3879 3880 # reduce to only the symmetric kpoints with weights 3881 kwmap = obj() 3882 for ik in kmap: 3883 if ik not in kwmap: 3884 kwmap[ik] = 1 3885 else: 3886 kwmap[ik] += 1 3887 #end if 3888 #end for 3889 nkpoints = len(kwmap) 3890 kpoints_symm = empty((nkpoints,self.dim),dtype=float) 3891 kweights_symm = empty((nkpoints,),dtype=float) 3892 n = 0 3893 for ik,kw in kwmap.items(): 3894 kpoints_symm[n] = kpoints[ik] 3895 kweights_symm[n] = kw 3896 n+=1 3897 #end for 3898 self.add_kpoints(kpoints_symm,kweights_symm) 3899 #end def add_symmetrized_kmesh 3900 3901 3902 def kpoints_unit(self,kpoints=None): 3903 if kpoints is None: 3904 kpoints = self.kpoints 3905 #end if 3906 return dot(kpoints,inv(self.kaxes)) 3907 #end def kpoints_unit 3908 3909 3910 def kpoints_reduced(self,kpoints=None): 3911 if kpoints is None: 3912 kpoints = self.kpoints 3913 #end if 3914 return kpoints*self.scale/(2*pi) 3915 #end def kpoints_reduced 3916 3917 3918 def kpoints_qmcpack(self,kpoints=None): 3919 if kpoints is None: 3920 kpoints = self.kpoints.copy() 3921 #end if 3922 kpoints = self.recenter_k(kpoints,kcenter=(0,0,0)) 3923 kpoints = self.kpoints_unit(kpoints) 3924 kpoints = -kpoints 3925 return kpoints 3926 #end def kpoints_qmcpack 3927 3928 3929 # test needed 3930 def inversion_symmetrize_kpoints(self,tol=1e-10,folded=False): 3931 kp = self.kpoints 3932 kaxes = self.kaxes 3933 ntable,dtable = self.neighbor_table(kp,-kp,kaxes,distances=True) 3934 pairs = set() 3935 keep = empty((len(kp),),dtype=bool) 3936 keep[:] = True 3937 for i in range(len(dtable)): 3938 if keep[i] and dtable[i,0]<tol: 3939 j = ntable[i,0] 3940 if j!=i and keep[j]: 3941 keep[j] = False 3942 self.kweights[i] += self.kweights[j] 3943 #end if 3944 #end if 3945 #end for 3946 self.kpoints = self.kpoints[keep] 3947 self.kweights = self.kweights[keep] 3948 if folded and self.folded_structure!=None: 3949 self.folded_structure.inversion_symmetrize_kpoints(tol) 3950 #end if 3951 #end def inversion_symmetrize_kpoints 3952 3953 3954 # test needed 3955 def unique_points(self,points,axes,weights=None,tol=1e-10): 3956 pmap = obj() 3957 npoints = len(points) 3958 if npoints>0: 3959 if weights is None: 3960 weights = ones((npoints,),dtype=int) 3961 #end if 3962 ntable,dtable = self.neighbor_table(points,points,axes,distances=True) 3963 keep = empty((npoints,),dtype=bool) 3964 keep[:] = True 3965 pmo = obj() 3966 for i in range(npoints): 3967 if keep[i]: 3968 pm = [] 3969 jn=0 3970 while jn<npoints and dtable[i,jn]<tol: 3971 j = ntable[i,jn] 3972 pm.append(j) 3973 if j!=i and keep[j]: 3974 keep[j] = False 3975 weights[i] += weights[j] 3976 #end if 3977 jn+=1 3978 #end while 3979 pmo[i] = set(pm) 3980 #end if 3981 #end for 3982 points = points[keep] 3983 weights = weights[keep] 3984 j=0 3985 for i in range(len(keep)): 3986 if keep[i]: 3987 pmap[j] = pmo[i] 3988 j+=1 3989 #end if 3990 #end for 3991 #end if 3992 return points,weights,pmap 3993 #end def unique_points 3994 3995 3996 # test needed 3997 def unique_points_fast(self,points,axes,weights=None,tol=1e-10): 3998 # use an O(N) cell table instead of an O(N^2) neighbor table 3999 pmap = obj() 4000 points = array(points) 4001 axes = array(axes) 4002 npoints = len(points) 4003 if npoints>0: 4004 if weights is None: 4005 weights = ones((npoints,),dtype=int) 4006 else: 4007 weights = array(weights) 4008 #end if 4009 keep = ones((npoints,),dtype=bool) 4010 # place all the points in the box, converted to unit coords 4011 upoints = array(points) 4012 axinv = inv(axes) 4013 for i in range(len(points)): 4014 u = dot(points[i],axinv) 4015 upoints[i] = u-floor(u) 4016 #end for 4017 # create an integer array of cell indices 4018 axmax = -1.0 4019 for a in axes: 4020 axmax = max(axmax,norm(a)) 4021 #end for 4022 # make an integer space corresponding to 1e-7 self.units spatial resolution 4023 cmax = uint64(1e7)*uint64(ceil(axmax)) 4024 ipoints = array(around(cmax*upoints),dtype=uint64) 4025 ipoints[ipoints==cmax] = 0 # make the outer boundary the same as the inner boundary 4026 # load the cell table with point indices 4027 # points in the same cell are identical 4028 ctable = obj() 4029 i=0 4030 for ip in ipoints: 4031 ip = tuple(ip) 4032 if ip not in ctable: 4033 ctable[ip] = i 4034 pmap[i] = [i] 4035 else: 4036 j = ctable[ip] 4037 keep[i] = False 4038 weights[j] += weights[i] 4039 pmap[j].append(i) 4040 #end if 4041 i+=1 4042 #end for 4043 points = points[keep] 4044 weights = weights[keep] 4045 #end if 4046 return points,weights,pmap 4047 #end def unique_points_fast 4048 4049 4050 # test needed 4051 def unique_positions(self,tol=1e-10,folded=False): 4052 pos,weights,pmap = self.unique_points(self.pos,self.axes) 4053 if len(pos)!=len(self.pos): 4054 self.pos = pos 4055 #end if 4056 if folded and self.folded_structure!=None: 4057 self.folded_structure.unique_positions(tol) 4058 #end if 4059 return pmap 4060 #end def unique_positions 4061 4062 4063 # test needed 4064 def unique_kpoints(self,tol=1e-10,folded=False): 4065 kmap = obj() 4066 kp = self.kpoints 4067 if len(kp)>0: 4068 kaxes = self.kaxes 4069 ntable,dtable = self.neighbor_table(kp,kp,kaxes,distances=True) 4070 npoints = len(kp) 4071 keep = empty((len(kp),),dtype=bool) 4072 keep[:] = True 4073 kmo = obj() 4074 for i in range(npoints): 4075 if keep[i]: 4076 km = [] 4077 jn=0 4078 while jn<npoints and dtable[i,jn]<tol: 4079 j = ntable[i,jn] 4080 km.append(j) 4081 if j!=i and keep[j]: 4082 keep[j] = False 4083 self.kweights[i] += self.kweights[j] 4084 #end if 4085 jn+=1 4086 #end while 4087 kmo[i] = set(km) 4088 #end if 4089 #end for 4090 self.kpoints = self.kpoints[keep] 4091 self.kweights = self.kweights[keep] 4092 j=0 4093 for i in range(len(keep)): 4094 if keep[i]: 4095 kmap[j] = kmo[i] 4096 j+=1 4097 #end if 4098 #end for 4099 #end if 4100 if folded and self.folded_structure!=None: 4101 self.folded_structure.unique_kpoints(tol) 4102 #end if 4103 return kmap 4104 #end def unique_kpoints 4105 4106 4107 def kmap(self): 4108 kmap = None 4109 if self.folded_structure!=None: 4110 fs = self.folded_structure 4111 self.kpoints = array(fs.kpoints) 4112 self.kweights = array(fs.kweights) 4113 kmap = self.unique_kpoints() 4114 #end if 4115 return kmap 4116 #end def kmap 4117 4118 4119 # test needed 4120 def select_twist(self,selector='smallest',tol=1e-6): 4121 index = None 4122 invalid_selector = False 4123 if isinstance(selector,str): 4124 if selector=='smallest': 4125 index = (self.kpoints**2).sum(1).argmin() 4126 elif selector=='random': 4127 index = randint(0,len(self.kpoints)-1) 4128 else: 4129 invalid_selector = True 4130 #end if 4131 elif isinstance(selector,(tuple,list,ndarray)): 4132 ku_sel = array(selector,dtype=float) 4133 n = 0 4134 for ku in self.kpoints_unit(): 4135 if norm(ku-ku_sel)<tol: 4136 index = n 4137 break 4138 #end if 4139 n+=1 4140 #end for 4141 if index is None: 4142 self.error('cannot identify twist number\ntwist requested: {0}\ntwists present: {1}'.format(ku_sel,sorted([tuple(k) for k in self.kpoints_unit()]))) 4143 #end if 4144 else: 4145 invalid_selector = True 4146 #end if 4147 if invalid_selector: 4148 self.error('cannot identify twist number\ninvalid selector provided: {0}\nvalid string inputs for selector: smallest, random\nselector can also be a length 3 tuple, list or array (a twist vector)'.format(selector)) 4149 #end if 4150 return index 4151 #end def select_twist 4152 4153 4154 # test needed 4155 def fold_pos(self,large,tol=0.001): 4156 vratio = large.volume()/self.volume() 4157 if abs(vratio-int(around(vratio)))>1e-6: 4158 self.error('cannot fold positions from large cell into current one\nlarge cell volume is not an integer multiple of the current one\nlarge cell volume: {0}\ncurrent cell volume: {1}\nvolume ratio: {2}'.format(large.volume(),self.volume(),vratio)) 4159 T,success = large.tilematrix(self,status=True) 4160 if not success: 4161 self.error('cannot fold positions from large cell into current one\ncells are related by non-integer tilematrix') 4162 #end if 4163 nnearest = int(around(vratio)) 4164 self.elem = large.elem.copy() 4165 self.pos = large.pos.copy() 4166 self.recenter() 4167 nt,dt = self.neighbor_table(distances=True) 4168 nt = nt[:,:nnearest] 4169 dt = dt[:,:nnearest] 4170 if dt.ravel().max()>tol: 4171 self.error('cannot fold positions from large cell into current one\npositions of equivalent atoms are further apart than the tolerance\nmax distance encountered: {0}\ntolerance: {1}'.format(dt.ravel().max(),tol)) 4172 #end if 4173 counts = zeros((len(self.pos),),dtype=int) 4174 for n in nt.ravel(): 4175 counts[n] += 1 4176 #end for 4177 if (counts!=nnearest).any(): 4178 self.error('cannot fold positions from large cell into current one\neach atom must have {0} equivalent positions\nsome atoms found with the following equivalent position counts: {1}'.format(nnearest,counts[counts!=nnearest])) 4179 #end if 4180 ind_visited = set() 4181 neigh_map = obj() 4182 keep = [] 4183 n=0 4184 for nset in nt: 4185 if n not in ind_visited: 4186 neigh_map[n] = nset 4187 keep.append(n) 4188 for ind in nset: 4189 ind_visited.add(ind) 4190 #end for 4191 #end if 4192 n+=1 4193 #end for 4194 if len(ind_visited)!=len(self.pos): 4195 self.error('cannot fold positions from large cell into current one\nsome equivalent atoms could not be identified') 4196 #end if 4197 new_elem = [] 4198 new_pos = [] 4199 for n in keep: 4200 nset = neigh_map[n] 4201 elist = list(set(self.elem[nset])) 4202 if len(elist)!=1: 4203 self.error('cannot fold positions from large cell into current one\nspecies of some equivalent atoms do not match') 4204 #end if 4205 new_elem.append(elist[0]) 4206 new_pos.append(self.pos[nset].mean(0)) 4207 #end for 4208 self.set_elem(new_elem) 4209 self.set_pos(new_pos) 4210 #end def fold_pos 4211 4212 4213 def pos_unit(self,pos=None): 4214 if pos is None: 4215 pos = self.pos 4216 #end if 4217 return dot(pos,inv(self.axes)) 4218 #end def pos_unit 4219 4220 4221 def pos_to_cartesian(self): 4222 self.pos = dot(self.pos,self.axes) 4223 #end def pos_to_cartesian 4224 4225 4226 # test needed 4227 def at_Gpoint(self): 4228 kpu = self.kpoints_unit() 4229 kg = array([0,0,0]) 4230 return len(kpu)==1 and norm(kg-kpu[0])<1e-6 4231 #end def at_Gpoint 4232 4233 4234 # test needed 4235 def at_Lpoint(self): 4236 kpu = self.kpoints_unit() 4237 kg = array([.5,.5,.5]) 4238 return len(kpu)==1 and norm(kg-kpu[0])<1e-6 4239 #end def at_Lpoint 4240 4241 4242 # test needed 4243 def at_real_kpoint(self): 4244 kpu = 2*self.kpoints_unit() 4245 return len(kpu)==1 and abs(kpu-around(kpu)).sum()<1e-6 4246 #end def at_real_kpoint 4247 4248 4249 # test needed 4250 def bonds(self,neighbors,vectors=False): 4251 if self.dim!=3: 4252 self.error('bonds is currently only implemented for 3 dimensions') 4253 #end if 4254 natoms,dim = self.pos.shape 4255 centers = empty((natoms,neighbors,dim)) 4256 distances = empty((natoms,neighbors)) 4257 vect = empty((natoms,neighbors,dim)) 4258 t = self.tile((3,3,3)) 4259 t.recenter(self.center) 4260 nn = nearest_neighbors(neighbors+1,t.pos,self.pos) 4261 for i in range(natoms): 4262 ii = nn[i,0] 4263 n=0 4264 for jj in nn[i,1:]: 4265 p1 = t.pos[ii] 4266 p2 = t.pos[jj] 4267 centers[i,n,:] = (p1+p2)/2 4268 distances[i,n]= sqrt(((p1-p2)**2).sum()) 4269 vect[i,n,:] = p2-p1 4270 n+=1 4271 #end for 4272 #end for 4273 sn = self.copy() 4274 nnr = nn[:,1:].ravel() 4275 sn.set_elem(t.elem[nnr]) 4276 sn.pos = t.pos[nnr] 4277 sn.recenter() 4278 indices = self.locate(sn.pos) 4279 indices = indices.reshape(natoms,neighbors) 4280 if not vectors: 4281 return indices,centers,distances 4282 else: 4283 return indices,centers,distances,vect 4284 #end if 4285 #end def bonds 4286 4287 4288 # test needed 4289 def displacement(self,reference,map=False): 4290 if self.dim!=3: 4291 self.error('displacement is currently only implemented for 3 dimensions') 4292 #end if 4293 ref = reference.tile((3,3,3)) 4294 ref.recenter(reference.center) 4295 rmap = array(3**3*list(range(len(reference.pos)),dtype=int)) 4296 nn = nearest_neighbors(1,ref.pos,self.pos).ravel() 4297 displacement = self.pos - ref.pos[nn] 4298 if not map: 4299 return displacement 4300 else: 4301 return displacement,rmap[nn] 4302 #end if 4303 #end def displacement 4304 4305 4306 # test needed 4307 def scalar_displacement(self,reference): 4308 return sqrt((self.displacement(reference)**2).sum(1)) 4309 #end def scalar_displacement 4310 4311 4312 # test needed 4313 def distortion(self,reference,neighbors): 4314 if self.dim!=3: 4315 self.error('distortion is currently only implemented for 3 dimensions') 4316 #end if 4317 if reference.volume()/self.volume() < 1.1: 4318 ref = reference.tile((3,3,3)) 4319 ref.recenter(reference.center) 4320 else: 4321 ref = reference 4322 #end if 4323 rbi,rbc,rbl,rbv = ref.bonds(neighbors,vectors=True) 4324 sbi,sbc,sbl,sbv = self.bonds(neighbors,vectors=True) 4325 nn = nearest_neighbors(1,reference.pos,self.pos).ravel() 4326 distortion = empty(sbv.shape) 4327 magnitude = empty((len(self.pos),)) 4328 for i in range(len(self.pos)): 4329 ir = nn[i] 4330 bonds = sbv[i] 4331 rbonds = rbv[ir] 4332 ib = empty((neighbors,),dtype=int) 4333 ibr = empty((neighbors,),dtype=int) 4334 r = list(range(neighbors)) 4335 rr = list(range(neighbors)) 4336 for n in range(neighbors): 4337 mindist = 1e99 4338 ibmin = -1 4339 ibrmin = -1 4340 for nb in r: 4341 for nbr in rr: 4342 d = norm(bonds[nb]-rbonds[nbr]) 4343 if d<mindist: 4344 mindist=d 4345 ibmin=nb 4346 ibrmin=nbr 4347 #end if 4348 #end for 4349 #end for 4350 ib[n]=ibmin 4351 ibr[n]=ibrmin 4352 r.remove(ibmin) 4353 rr.remove(ibrmin) 4354 #end for 4355 #end for 4356 d = bonds[ib]-rbonds[ibr] 4357 distortion[i] = d 4358 magnitude[i] = (sqrt((d**2).sum(axis=1))).sum() 4359 #end for 4360 return distortion,magnitude 4361 #end def distortion 4362 4363 4364 # test needed 4365 def bond_compression(self,reference,neighbors): 4366 ref = reference 4367 rbi,rbc,rbl = ref.bonds(neighbors) 4368 sbi,sbc,sbl = self.bonds(neighbors) 4369 bondlen = rbl.mean() 4370 return abs(1.-sbl/bondlen).max(axis=1) 4371 #end def bond_compression 4372 4373 4374 # test needed 4375 def boundary(self,dims=(0,1,2),dtol=1e-6): 4376 dim_eff = len(dims) 4377 natoms,dim = self.pos.shape 4378 bdims = array(dim*[False]) 4379 for d in dims: 4380 bdims[d] = True 4381 #end for 4382 p = self.pos[:,bdims] 4383 indices = convex_hull(p,dim_eff,dtol) 4384 return indices 4385 #end def boundary 4386 4387 4388 def embed(self,small,dims=(0,1,2),dtol=1e-6,utol=1e-6): 4389 small = small.copy() 4390 small.recenter() 4391 center = array(self.center) 4392 self.recenter(small.center) 4393 bind = small.boundary(dims,dtol) 4394 bpos = small.pos[bind] 4395 belem= small.elem[bind] 4396 nn = nearest_neighbors(1,self.pos,bpos).ravel() 4397 mpos = self.pos[nn] 4398 dr = (mpos-bpos).mean(0) 4399 for i in range(len(bpos)): 4400 bpos[i]+=dr 4401 #end for 4402 dmax = sqrt(((mpos-bpos)**2).sum(1)).max() 4403 for i in range(len(small.pos)): 4404 small.pos[i]+=dr 4405 #end for 4406 ins,surface = small.inside(self.pos,tol=utol,separate=True) 4407 replaced = empty((len(self.pos),),dtype=bool) 4408 replaced[:] = False 4409 inside = replaced.copy() 4410 inside[ins] = True 4411 nn = nearest_neighbors(1,self.pos,small.pos).ravel() 4412 elist = list(self.elem) 4413 plist = list(self.pos) 4414 pos = small.pos 4415 elem = small.elem 4416 for i in range(len(pos)): 4417 n = nn[i] 4418 if not replaced[n]: 4419 elist[n] = elem[i] 4420 plist[n] = pos[i] 4421 replaced[n] = True 4422 else: 4423 elist.append(elem[i]) 4424 plist.append(pos[i]) 4425 #end if 4426 #end for 4427 remove = arange(len(self.pos))[inside & logical_not(replaced)] 4428 remove.sort() 4429 remove = flipud(remove) 4430 for i in remove: 4431 elist.pop(i) 4432 plist.pop(i) 4433 #end for 4434 self.set_elem(elist) 4435 self.pos = array(plist) 4436 self.recenter(center) 4437 return dmax 4438 #end def embed 4439 4440 4441 # test needed 4442 def shell(self,cell,neighbors,direction='in'): 4443 if self.dim!=3: 4444 self.error('shell is currently only implemented for 3 dimensions') 4445 #end if 4446 dd = {'in':equate,'out':negate} 4447 dir = dd[direction] 4448 natoms,dim=self.pos.shape 4449 ncells=3**3 4450 ntile = ncells*natoms 4451 pos = empty((ntile,dim)) 4452 ind = empty((ntile,),dtype=int) 4453 oind = list(range(natoms)) 4454 for nt in range(ncells): 4455 n=nt*natoms 4456 ind[n:n+natoms]=oind[:] 4457 pos[n:n+natoms]=self.pos[:] 4458 #end for 4459 nt=0 4460 for k in -1,0,1: 4461 for j in -1,0,1: 4462 for i in -1,0,1: 4463 iv = array([[i,j,k]]) 4464 v = dot(iv,self.axes) 4465 for d in range(dim): 4466 ns = nt*natoms 4467 ne = ns+natoms 4468 pos[ns:ne,d] += v[0,d] 4469 #end for 4470 nt+=1 4471 #end for 4472 #end for 4473 #end for 4474 4475 inside = empty(ntile,) 4476 inside[:]=False 4477 ins = cell.inside(pos) 4478 inside[ins]=True 4479 4480 iishell = set() 4481 nn = nearest_neighbors(neighbors,pos) 4482 for ii in range(len(nn)): 4483 for jj in nn[ii]: 4484 in1 = inside[ii] 4485 in2 = inside[jj] 4486 if dir(in1 and not in2): 4487 iishell.add(ii) 4488 #end if 4489 if dir(in2 and not in1): 4490 iishell.add(jj) 4491 #end if 4492 #end if 4493 #end if 4494 ishell = ind[list(iishell)] 4495 return ishell 4496 #end def shell 4497 4498 4499 def interpolate(self,other,images,min_image=True,recenter=True,match_com=False,chained=False): 4500 s1 = self.copy() 4501 s2 = other.copy() 4502 s1.remove_folded() 4503 s2.remove_folded() 4504 if s2.units!=s1.units: 4505 s2.change_units(s1.units) 4506 #end if 4507 if (s1.elem!=s2.elem).any(): 4508 self.error('cannot interpolate structures, atoms do not match\n atoms1: {0}\n atoms2: {1}'.format(s1.elem,s2.elem)) 4509 #end if 4510 structures = [] 4511 npath = images+2 4512 c1 = s1.center 4513 c2 = s2.center 4514 ax1 = s1.axes 4515 ax2 = s2.axes 4516 pos1 = s1.pos 4517 pos2 = s2.pos 4518 min_image &= abs(ax1-ax2).max()<1e-6 4519 if min_image: 4520 dp = self.min_image_vectors(pos1,pos2,ax1,pairs=False) 4521 pos2 = pos1 + dp 4522 #end if 4523 if match_com: 4524 com1 = pos1.mean(axis=0) 4525 com2 = pos2.mean(axis=1) 4526 dcom = com1-com2 4527 for n in range(len(pos2)): 4528 pos2[n] += dcom 4529 #end for 4530 if chained: 4531 other.pos = pos2 4532 #end if 4533 #end if 4534 for n in range(npath): 4535 f1 = 1.-float(n)/(npath-1) 4536 f2 = 1.-f1 4537 center = f1*c1 + f2*c2 4538 axes = f1*ax1 + f2*ax2 4539 pos = f1*pos1 + f2*pos2 4540 s = s1.copy() 4541 s.reset_axes(axes) 4542 s.center = center 4543 s.pos = pos 4544 if recenter: 4545 s.recenter() 4546 #end if 4547 structures.append(s) 4548 #end for 4549 return structures 4550 #end def interpolate 4551 4552 4553 # returns madelung potential constant v_M 4554 # see equation 7 in PRB 78 125106 (2008) 4555 def madelung(self,axes=None,tol=1e-10): 4556 if self.dim!=3: 4557 self.error('madelung is currently only implemented for 3 dimensions') 4558 #end if 4559 if axes is None: 4560 a = self.axes.T.copy() 4561 else: 4562 a = axes.T.copy() 4563 #end if 4564 if self.units!='B': 4565 a = convert(a,self.units,'B') 4566 #end if 4567 volume = abs(det(a)) 4568 b = 2*pi*inv(a).T 4569 rconv = 8*(3.*volume/(4*pi))**(1./3) 4570 kconv = 2*pi/rconv 4571 gconst = -1./(4*kconv**2) 4572 vmc = -pi/(kconv**2*volume)-2*kconv/sqrt(pi) 4573 4574 nshells = 20 4575 vshell = [0.] 4576 p = Sobj() 4577 m = Sobj() 4578 for n in range(1,nshells+1): 4579 i = mgrid[-n:n+1,-n:n+1,-n:n+1] 4580 i = i.reshape(3,(2*n+1)**3) 4581 R = sqrt((dot(a,i)**2).sum(0)) 4582 G2 = (dot(b,i)**2).sum(0) 4583 4584 izero = n + n*(2*n+1) + n*(2*n+1)**2 4585 4586 p.R = R[0:izero] 4587 p.G2 = G2[0:izero] 4588 m.R = R[izero+1:] 4589 m.G2 = G2[izero+1:] 4590 domains = [p,m] 4591 vshell.append(0.) 4592 for d in domains: 4593 vshell[n] += (erfc(kconv*d.R)/d.R).sum() + 4*pi/volume*(exp(gconst*d.G2)/d.G2).sum() 4594 #end for 4595 if abs(vshell[n]-vshell[n-1])<tol: 4596 break 4597 #end if 4598 #end for 4599 vm = vmc + vshell[-1] 4600 4601 if axes is None: 4602 self.Vmadelung = vm 4603 #end if 4604 return vm 4605 #end def madelung 4606 4607 4608 def makov_payne(self,q=1,eps=1.0,units='Ha',order=1): 4609 if order!=1: 4610 self.error('Only first order Makov-Payne correction is currently supported.') 4611 #end if 4612 if 'Vmadelung' not in self: 4613 vm = self.madelung() 4614 else: 4615 vm = self.Vmadelung 4616 #end if 4617 mp = -0.5*q**2*vm/eps 4618 if units!='Ha': 4619 mp = convert(mp,'Ha',units) 4620 #end if 4621 return mp 4622 #end def makov_payne 4623 4624 4625 def read(self,filepath,format=None,elem=None,block=None,grammar='1.1',cell='prim',contents=False): 4626 if os.path.exists(filepath): 4627 path,file = os.path.split(filepath) 4628 if format is None: 4629 if '.' in file: 4630 name,format = file.rsplit('.',1) 4631 elif file.lower().endswith('poscar'): 4632 format = 'poscar' 4633 else: 4634 self.error('file format could not be determined\nunrecognized file: {0}'.format(filepath)) 4635 #end if 4636 #end if 4637 elif not contents: 4638 self.error('file does not exist: {0}'.format(filepath)) 4639 #end if 4640 if format is None: 4641 self.error('file format must be provided') 4642 #end if 4643 self.mag = None 4644 self.frozen = None 4645 format = format.lower() 4646 if format=='xyz': 4647 self.read_xyz(filepath) 4648 elif format=='xsf': 4649 self.read_xsf(filepath) 4650 elif format=='poscar': 4651 self.read_poscar(filepath,elem=elem) 4652 elif format=='cif': 4653 self.read_cif(filepath,block=block,grammar=grammar,cell=cell) 4654 elif format=='fhi-aims': 4655 self.read_fhi_aims(filepath) 4656 else: 4657 self.error('cannot read structure from file\nunsupported file format: {0}'.format(format)) 4658 #end if 4659 if self.has_axes(): 4660 self.set_bconds('ppp') 4661 #end if 4662 #end def read 4663 4664 4665 def read_xyz(self,filepath): 4666 elem = [] 4667 pos = [] 4668 if os.path.exists(filepath): 4669 lines = open(filepath,'r').read().strip().splitlines() 4670 else: 4671 lines = filepath.strip().splitlines() # "filepath" is file contents 4672 #end if 4673 if len(lines)>1: 4674 ntot = int(lines[0].strip()) 4675 natoms = 0 4676 e = None 4677 p = None 4678 try: 4679 tokens = lines[1].split() 4680 if len(tokens)==4: 4681 e = tokens[0] 4682 p = array(tokens[1:],float) 4683 #end if 4684 except: 4685 None 4686 #end try 4687 if p is not None: 4688 elem.append(e) 4689 pos.append(p) 4690 natoms+=1 4691 #end if 4692 if len(lines)>2: 4693 for l in lines[2:]: 4694 tokens = l.split() 4695 if len(tokens)==4: 4696 elem.append(tokens[0]) 4697 pos.append(array(tokens[1:],float)) 4698 natoms+=1 4699 if natoms==ntot: 4700 break 4701 #end if 4702 #end if 4703 #end for 4704 #end if 4705 if natoms!=ntot: 4706 self.error('xyz file read failed\nattempted to read file: {0}\nnumber of atoms expected: {1}\nnumber of atoms found: {2}'.format(filepath,ntot,natoms)) 4707 #end if 4708 #end if 4709 self.dim = 3 4710 self.set_elem(elem) 4711 self.pos = array(pos) 4712 self.units = 'A' 4713 #end def read_xyz 4714 4715 4716 def read_xsf(self,filepath): 4717 if isinstance(filepath,XsfFile): 4718 f = filepath 4719 elif os.path.exists(filepath): 4720 f = XsfFile(filepath) 4721 else: 4722 f = XsfFile() 4723 f.read_text(filepath) # "filepath" is file contents 4724 #end if 4725 elem = [] 4726 for n in f.elem: 4727 if isinstance(n,str): 4728 elem.append(n) 4729 else: 4730 elem.append(pt.simple_elements[n].symbol) 4731 #end if 4732 #end for 4733 self.dim = 3 4734 self.units = 'A' 4735 self.reset_axes(f.primvec) 4736 self.set_elem(elem) 4737 self.pos = f.pos 4738 #end def read_xsf 4739 4740 4741 def read_poscar(self,filepath,elem=None): 4742 if os.path.exists(filepath): 4743 lines = open(filepath,'r').read().splitlines() 4744 else: 4745 lines = filepath.splitlines() # "filepath" is file contents 4746 #end if 4747 nlines = len(lines) 4748 min_lines = 8 4749 if nlines<min_lines: 4750 self.error('POSCAR file must have at least {0} lines\n only {1} lines found'.format(min_lines,nlines)) 4751 #end if 4752 dim = 3 4753 scale = float(lines[1].strip()) 4754 axes = empty((dim,dim)) 4755 axes[0] = array(lines[2].split(),dtype=float) 4756 axes[1] = array(lines[3].split(),dtype=float) 4757 axes[2] = array(lines[4].split(),dtype=float) 4758 if scale<0.0: 4759 scale = abs(scale)/det(axes) 4760 #end if 4761 axes = scale*axes 4762 tokens = lines[5].split() 4763 if tokens[0].isdigit(): 4764 counts = array(tokens,dtype=int) 4765 if elem is None: 4766 self.error('variable elem must be provided to read_poscar() to assign atomic species to positions for POSCAR format') 4767 elif len(elem)!=len(counts): 4768 self.error('one elem must be given for each element count in the POSCAR file\n number of elem counts: {0}\n number of elem given: {1}'.format(len(counts),len(elem))) 4769 #end if 4770 lcur = 6 4771 else: 4772 elem = tokens 4773 counts = array(lines[6].split(),dtype=int) 4774 lcur = 7 4775 #end if 4776 species = elem 4777 # relabel species that have multiple occurances 4778 sset = set(species) 4779 for spec in sset: 4780 if species.count(spec)>1: 4781 cnt=0 4782 for n in range(len(species)): 4783 specn = species[n] 4784 if specn==spec: 4785 cnt+=1 4786 species[n] = specn+str(cnt) 4787 #end if 4788 #end for 4789 #end if 4790 #end for 4791 elem = [] 4792 for i in range(len(counts)): 4793 elem.extend(counts[i]*[species[i]]) 4794 #end for 4795 self.dim = dim 4796 self.units = 'A' 4797 self.reset_axes(axes) 4798 4799 if lcur<len(lines) and len(lines[lcur])>0: 4800 c = lines[lcur].lower().strip()[0] 4801 lcur+=1 4802 else: 4803 return 4804 #end if 4805 selective_dynamics = c=='s' 4806 if selective_dynamics: # Selective dynamics 4807 if lcur<len(lines) and len(lines[lcur])>0: 4808 c = lines[lcur].lower().strip()[0] 4809 lcur+=1 4810 else: 4811 return 4812 #end if 4813 #end if 4814 cartesian = c=='c' or c=='k' 4815 npos = counts.sum() 4816 if lcur+npos>len(lines): 4817 return 4818 #end if 4819 spos = [] 4820 for i in range(npos): 4821 spos.append(lines[lcur+i].split()) 4822 #end for 4823 spos = array(spos) 4824 pos = array(spos[:,0:3],dtype=float) 4825 if cartesian: 4826 pos = scale*pos 4827 else: 4828 pos = dot(pos,axes) 4829 #end if 4830 self.set_elem(elem) 4831 self.pos = pos 4832 if selective_dynamics or spos.shape[1]>3: 4833 move = array(spos[:,3:6],dtype=str) 4834 self.freeze(list(range(self.size())),directions=move=='F') 4835 #end if 4836 #end def read_poscar 4837 4838 4839 def read_cif(self,filepath,block=None,grammar='1.1',cell='prim'): 4840 axes,elem,pos,units = read_cif(filepath,block,grammar,cell,args_only=True) 4841 self.dim = 3 4842 self.set_axes(axes) 4843 self.set_elem(elem) 4844 self.pos = pos 4845 self.units = units 4846 #end def read_cif 4847 4848 4849 # test needed 4850 def read_fhi_aims(self,filepath): 4851 if os.path.exists(filepath): 4852 lines = open(filepath,'r').read().splitlines() 4853 else: 4854 lines = filepath.splitlines() # "filepath" is contents 4855 #end if 4856 axes = [] 4857 pos = [] 4858 elem = [] 4859 constrain_relax = [] 4860 unit_pos = False 4861 for line in lines: 4862 ls = line.strip() 4863 if len(ls)>0 and ls[0]!='#': 4864 tokens = ls.split() 4865 t0 = tokens[0] 4866 if t0=='lattice_vector': 4867 axes.append(tokens[1:]) 4868 elif t0=='atom_frac': 4869 pos.append(tokens[1:4]) 4870 elem.append(tokens[4]) 4871 unit_pos = True 4872 elif t0=='atom': 4873 pos.append(tokens[1:4]) 4874 elem.append(tokens[4]) 4875 elif t0=='constrain_relaxation': 4876 constrain_relax.append(tokens[1]) 4877 elif t0.startswith('initial'): 4878 None 4879 else: 4880 #None 4881 self.error('unrecogonized or not yet supported token in fhi-aims geometry file: {0}'.format(t0)) 4882 #end if 4883 #end if 4884 #end for 4885 axes = array(axes,dtype=float) 4886 pos = array(pos,dtype=float) 4887 if unit_pos: 4888 pos = dot(pos,axes) 4889 #end if 4890 self.dim = 3 4891 if len(axes)>0: 4892 self.set_axes(axes) 4893 #end if 4894 self.set_elem(elem) 4895 self.pos = pos 4896 self.units = 'A' 4897 if len(constrain_relax)>0: 4898 constrain_relax = array(constrain_relax) 4899 self.freeze(list(range(self.size())),directions=constrain_relax=='.true.') 4900 #end if 4901 #end def read_fhi_aims 4902 4903 4904 def write(self,filepath=None,format=None): 4905 if filepath is None and format is None: 4906 self.error('please specify either the filepath or format arguments to write()') 4907 elif format is None: 4908 if '.' in filepath: 4909 format = filepath.split('.')[-1] 4910 else: 4911 self.error('file format could not be determined\neither request the format directly with the format keyword or add a file format extension to the file name') 4912 #end if 4913 #end if 4914 format = format.lower() 4915 if format=='xyz': 4916 c = self.write_xyz(filepath) 4917 elif format=='xsf': 4918 c = self.write_xsf(filepath) 4919 elif format=='poscar': 4920 c = self.write_poscar(filepath) 4921 elif format=='fhi-aims': 4922 c = self.write_fhi_aims(filepath) 4923 else: 4924 self.error('file format {0} is unrecognized'.format(format)) 4925 #end if 4926 return c 4927 #end def write 4928 4929 4930 def write_xyz(self,filepath=None,header=True,units='A'): 4931 if self.dim!=3: 4932 self.error('write_xyz is currently only implemented for 3 dimensions') 4933 #end if 4934 s = self.copy() 4935 s.change_units(units) 4936 c='' 4937 if header: 4938 c += str(len(s.elem))+'\n\n' 4939 #end if 4940 for i in range(len(s.elem)): 4941 e = s.elem[i] 4942 p = s.pos[i] 4943 c+=' {0:2} {1:12.8f} {2:12.8f} {3:12.8f}\n'.format(e,p[0],p[1],p[2]) 4944 #end for 4945 if filepath!=None: 4946 open(filepath,'w').write(c) 4947 #end if 4948 return c 4949 #end def write_xyz 4950 4951 4952 def write_xsf(self,filepath=None): 4953 if self.dim!=3: 4954 self.error('write_xsf is currently only implemented for 3 dimensions') 4955 #end if 4956 s = self.copy() 4957 s.change_units('A') 4958 c = ' CRYSTAL\n' 4959 c += ' PRIMVEC\n' 4960 for a in s.axes: 4961 c += ' {0:12.8f} {1:12.8f} {2:12.8f}\n'.format(*a) 4962 #end for 4963 c += ' PRIMCOORD\n' 4964 c += ' {0} 1\n'.format(len(s.elem)) 4965 for i in range(len(s.elem)): 4966 e = s.elem[i] 4967 identified = e in pt.elements 4968 if not identified: 4969 if len(e)>2: 4970 e = e[0:2] 4971 elif len(e)==2: 4972 e = e[0:1] 4973 #end if 4974 identified = e in pt.elements 4975 #end if 4976 if not identified: 4977 self.error('{0} is not an element\nxsf file cannot be written'.format(e)) 4978 #end if 4979 enum = pt.elements[e].atomic_number 4980 r = s.pos[i] 4981 c += ' {0:>3} {1:12.8f} {2:12.8f} {3:12.8f}\n'.format(enum,r[0],r[1],r[2]) 4982 #end for 4983 if filepath!=None: 4984 open(filepath,'w').write(c) 4985 #end if 4986 return c 4987 #end def write_xsf 4988 4989 4990 def write_poscar(self,filepath=None): 4991 s = self.copy() 4992 s.change_units('A') 4993 species,species_count = s.order_by_species() 4994 poscar = PoscarFile() 4995 poscar.scale = 1.0 4996 poscar.axes = s.axes 4997 poscar.elem = species 4998 poscar.elem_count = species_count 4999 poscar.coord = 'cartesian' 5000 poscar.pos = s.pos 5001 c = poscar.write_text() 5002 if filepath is not None: 5003 open(filepath,'w').write(c) 5004 #end if 5005 return c 5006 #end def write_poscar 5007 5008 5009 # test needed 5010 def write_fhi_aims(self,filepath=None): 5011 s = self.copy() 5012 s.change_units('A') 5013 c = '' 5014 c+='\n' 5015 for a in s.axes: 5016 c += 'lattice_vector {0: 12.8f} {1: 12.8f} {2: 12.8f}\n'.format(*a) 5017 #end for 5018 c+='\n' 5019 for p,e in zip(self.pos,self.elem): 5020 c += 'atom_frac {0: 12.8f} {1: 12.8f} {2: 12.8f} {3}\n'.format(p[0],p[1],p[2],e) 5021 #end for 5022 if filepath!=None: 5023 open(filepath,'w').write(c) 5024 #end if 5025 return c 5026 #end def write_fhi_aims 5027 5028 5029 def plot2d_ax(self,ix,iy,*args,**kwargs): 5030 if self.dim!=3: 5031 self.error('plot2d_ax is currently only implemented for 3 dimensions') 5032 #end if 5033 iz = list(set([0,1,2])-set([ix,iy]))[0] 5034 ax = self.axes.copy() 5035 a = self.axes[iz] 5036 dc = self.center-ax.sum(0)/2 5037 pp = array([0*a,ax[ix],ax[ix]+ax[iy],ax[iy],0*a]) 5038 for i in range(len(pp)): 5039 pp[i]+=dc 5040 pp[i]-=dot(a,pp[i])/dot(a,a)*a 5041 #end for 5042 plot(pp[:,ix],pp[:,iy],*args,**kwargs) 5043 #end def plot2d_ax 5044 5045 5046 def plot2d_pos(self,ix,iy,*args,**kwargs): 5047 if self.dim!=3: 5048 self.error('plot2d_pos is currently only implemented for 3 dimensions') 5049 #end if 5050 iz = list(set([0,1,2])-set([ix,iy]))[0] 5051 pp = self.pos.copy() 5052 a = self.axes[iz] 5053 for i in range(len(pp)): 5054 pp[i] -= dot(a,pp[i])/dot(a,a)*a 5055 #end for 5056 plot(pp[:,ix],pp[:,iy],*args,**kwargs) 5057 #end def plot2d_pos 5058 5059 5060 def plot2d(self,pos_style='b.',ax_style='k-'): 5061 if self.dim!=3: 5062 self.error('plot2d is currently only implemented for 3 dimensions') 5063 #end if 5064 subplot(1,3,1) 5065 self.plot2d_ax(0,1,ax_style,lw=2) 5066 self.plot2d_pos(0,1,pos_style) 5067 title('a1,a2') 5068 subplot(1,3,2) 5069 self.plot2d_ax(1,2,ax_style,lw=2) 5070 self.plot2d_pos(1,2,pos_style) 5071 title('a2,a3') 5072 subplot(1,3,3) 5073 self.plot2d_ax(2,0,ax_style,lw=2) 5074 self.plot2d_pos(2,0,pos_style) 5075 title('a3,a1') 5076 #end def plot2d 5077 5078 5079 def plot2d_kax(self,ix,iy,*args,**kwargs): 5080 if self.dim!=3: 5081 self.error('plot2d_ax is currently only implemented for 3 dimensions') 5082 #end if 5083 iz = list(set([0,1,2])-set([ix,iy]))[0] 5084 ax = self.kaxes.copy() 5085 a = ax[iz] 5086 dc = 0*a 5087 pp = array([0*a,ax[ix],ax[ix]+ax[iy],ax[iy],0*a]) 5088 for i in range(len(pp)): 5089 pp[i]+=dc 5090 pp[i]-=dot(a,pp[i])/dot(a,a)*a 5091 #end for 5092 plot(pp[:,ix],pp[:,iy],*args,**kwargs) 5093 #end def plot2d_kax 5094 5095 5096 def plot2d_kp(self,ix,iy,*args,**kwargs): 5097 if self.dim!=3: 5098 self.error('plot2d_kp is currently only implemented for 3 dimensions') 5099 #end if 5100 iz = list(set([0,1,2])-set([ix,iy]))[0] 5101 pp = self.kpoints.copy() 5102 a = self.kaxes[iz] 5103 for i in range(len(pp)): 5104 pp[i] -= dot(a,pp[i])/dot(a,a)*a 5105 #end for 5106 plot(pp[:,ix],pp[:,iy],*args,**kwargs) 5107 #end def plot2d_kp 5108 5109 5110 def show(self,viewer='vmd',filepath='/tmp/tmp.xyz'): 5111 if self.dim!=3: 5112 self.error('show is currently only implemented for 3 dimensions') 5113 #end if 5114 self.write_xyz(filepath) 5115 os.system(viewer+' '+filepath) 5116 #end def show 5117 5118 5119 # minimal ASE Atoms-like interface to Structure objects for spglib 5120 def get_cell(self): 5121 return self.axes.copy() 5122 #end def get_cell 5123 5124 def get_scaled_positions(self): 5125 return self.pos_unit() 5126 #end def get_scaled_positions 5127 5128 def get_number_of_atoms(self): 5129 return len(self.elem) 5130 #end def get_number_of_atoms 5131 5132 def get_atomic_numbers(self): 5133 an = [] 5134 for e in self.elem: 5135 iselem,esymb = is_element(e,symbol=True) 5136 if not iselem: 5137 self.error('Atomic symbol, {}, not recognized'.format(esymb)) 5138 else: 5139 an.append(ptable[esymb].atomic_number) 5140 #end if 5141 #end for 5142 return array(an,dtype='intc') 5143 #end def get_atomic_numbers 5144 5145 def get_magnetic_moments(self): 5146 self.error('structure objects do not currently support magnetic moments') 5147 #end def get_magnetic_moments 5148 5149 5150 # direct spglib interface 5151 def spglib_cell(self): 5152 lattice = self.axes.copy() 5153 positions = self.pos_unit() 5154 numbers = self.get_atomic_numbers() 5155 cell = (lattice,positions,numbers) 5156 return cell 5157 #end def spglib_cell 5158 5159 5160 def get_symmetry(self,symprec=1e-5): 5161 cell = self.spglib_cell() 5162 return spglib.get_symmetry(cell,symprec=symprec) 5163 #end def get_symmetry 5164 5165 5166 def get_symmetry_dataset(self,symprec=1e-5, angle_tolerance=-1.0, hall_number=0): 5167 cell = self.spglib_cell() 5168 ds = spglib.get_symmetry_dataset( 5169 cell, 5170 symprec = symprec, 5171 angle_tolerance = angle_tolerance, 5172 hall_number = hall_number, 5173 ) 5174 return ds 5175 #end def get_symmetry 5176 5177 5178 # functions based on direct spglib interface 5179 5180 def symmetry_data(self,*args,**kwargs): 5181 ds = self.get_symmetry_dataset(*args,**kwargs) 5182 ds = obj(ds) 5183 for k,v in ds.items(): 5184 if isinstance(v,dict): 5185 ds[k] = obj(v) 5186 #end if 5187 #end for 5188 return ds 5189 #end def symmetry_data 5190 5191 5192 def bravais_lattice_name(self,symm_data=None): 5193 if symm_data is None: 5194 symm_data = self.symmetry_data() 5195 #end if 5196 sg = symm_data.number 5197 name = symm_data.international 5198 if not isinstance(sg,int) or sg<1 or sg>230: 5199 self.error('Invalid space group from spglib: {}'.format(sg)) 5200 #end if 5201 if not isinstance(name,str): 5202 self.error('Invalid space group name from spglib: {}'.format(name)) 5203 #end if 5204 bv = None 5205 if sg>=1 and sg<=2: 5206 bv = 'triclinic_'+name[0] 5207 elif sg>=3 and sg<=15: 5208 bv = 'monoclinic_'+name[0] 5209 elif sg>=16 and sg<=74: 5210 bv = 'orthorhombic_'+name[0] 5211 elif sg>=75 and sg<=142: 5212 bv = 'tetragonal_'+name[0] 5213 elif sg>=143 and sg<=167: 5214 bv = 'trigonal_'+name[0] 5215 elif sg>=168 and sg<=194: 5216 bv = 'hexagonal_'+name[0] 5217 elif sg>=195 and sg<=230: 5218 bv = 'cubic_'+name[0] 5219 #end if 5220 if bv is None: 5221 self.error('Bravais lattice could not be determined.\nSpace group number and name: {} {}'.format(sg,name)) 5222 #end if 5223 return bv 5224 #end def bravais_lattice_name 5225 5226 5227 # test needed 5228 def space_group_operations(self,tol=1e-5,unit=False): 5229 ds = self.get_symmetry(symprec=tol) 5230 if ds is None: 5231 self.error('Symmetry search failed.\nspglib error message:\n{}'.format(spglib.get_error_message())) 5232 #end if 5233 ds = obj(ds) 5234 rotations = ds.rotations 5235 translations = ds.translations 5236 5237 if not unit: 5238 # Transform to Cartesian 5239 axes = self.axes 5240 axinv = np.linalg.inv(axes) 5241 for n,(R,t) in enumerate(zip(rotations,translations)): 5242 rotations[n] = np.dot(axinv,np.dot(R,axes)) 5243 translations[n] = np.dot(t,axes) 5244 #end for 5245 #end if 5246 5247 return rotations,translations 5248 #end def space_group_operations 5249 5250 5251 def point_group_operations(self,tol=1e-5,unit=False): 5252 rotations,translations = self.space_group_operations(tol=tol,unit=unit) 5253 no_trans = translations.max(axis=1) < tol 5254 return rotations[no_trans] 5255 #end def point_group_operations 5256 5257 5258 def check_point_group_operations(self,rotations=None,tol=1e-5,unit=False,dtol=1e-5,ncheck=1,exit=False): 5259 if rotations is None: 5260 rotations = self.point_group_operations(tol=tol,unit=unit) 5261 #ned if 5262 r = self.pos 5263 if ncheck=='all': 5264 ncheck = len(r) 5265 #end if 5266 all_same = True 5267 for n in range(ncheck): 5268 rc = r[n] 5269 for R in rotations: 5270 rp = np.dot(r-rc,R)+rc 5271 dt = self.min_image_distances(r,rp) 5272 same = True 5273 for d in dt: 5274 same &= dt.min()<dtol 5275 #end for 5276 all_same &= same 5277 #end for 5278 #end for 5279 if not all_same and exit: 5280 self.error('Point group operators are not all symmetries of the structure.') 5281 #end if 5282 return all_same 5283 #end def check_point_group_operations 5284 5285 5286 def equivalent_atoms(self): 5287 ds = self.symmetry_data() 5288 5289 # collect sets of species labels 5290 species_by_specnum = obj() 5291 for e,sn in zip(self.elem,ds.equivalent_atoms): 5292 is_elem,es = is_element(e,symbol=True) 5293 if sn not in species_by_specnum: 5294 species_by_specnum[sn] = set() 5295 #end if 5296 species_by_specnum[sn].add(es) 5297 #end for 5298 for sn,sset in species_by_specnum.items(): 5299 if len(sset)>1: 5300 self.error('Cannot find equivalent atoms.\nMultiple atomic species were marked as being equivalent.\nSpecies marked in this way: {}'.format(list(sset))) 5301 #end if 5302 species_by_specnum[sn] = list(sset)[0] 5303 #end for 5304 5305 # give each unique species a unique label 5306 labels_by_specnum = obj() 5307 species_list = list(species_by_specnum.values()) 5308 species_set = set(species_list) 5309 species_counts = obj() 5310 for s in species_set: 5311 species_counts[s] = species_list.count(s) 5312 #end for 5313 spec_counts = obj() 5314 for sn,s in species_by_specnum.items(): 5315 if species_counts[s]==1: 5316 labels_by_specnum[sn] = s 5317 else: 5318 if s not in spec_counts: 5319 spec_counts[s] = 1 5320 else: 5321 spec_counts[s] += 1 5322 #end if 5323 labels_by_specnum[sn] = s + str(spec_counts[s]) 5324 #end if 5325 #end for 5326 5327 # find indices for each unique species 5328 equiv_indices = obj() 5329 for s in labels_by_specnum.values(): 5330 equiv_indices[s] = list() 5331 #end for 5332 for i,sn in enumerate(ds.equivalent_atoms): 5333 equiv_indices[labels_by_specnum[sn]].append(i) 5334 #end for 5335 for s,indices in equiv_indices.items(): 5336 equiv_indices[s] = np.array(indices,dtype=int) 5337 #end for 5338 5339 return equiv_indices 5340 #end def equivalent_atoms 5341 5342 5343 # operations to support restricted cases for RMG code 5344 5345 # supported rmg lattices 5346 rmg_lattices = obj( 5347 orthorhombic_P = 'Orthorhombic Primitive', 5348 tetragonal_P = 'Tetragonal Primitive', 5349 hexagonal_P = 'Hexagonal Primitive', 5350 cubic_P = 'Cubic Primitive', 5351 cubic_I = 'Cubic Body Centered', 5352 cubic_F = 'Cubic Face Centered', 5353 ) 5354 5355 def rmg_lattice(self,allow_tile=False,all_results=False,ret_bravais=False,exit=False,warn=False): 5356 # output variables 5357 rmg_lattice = None 5358 tmatrix = None 5359 5360 rmg_lattices = Structure.rmg_lattices 5361 5362 # represent current bravais lattice 5363 s = Structure( 5364 units = str(self.units), 5365 axes = self.axes.copy(), 5366 elem = ['H'], 5367 pos = [[0,0,0]], 5368 ) 5369 5370 # get standard primitive cell based on bravais lattice 5371 sp = s.primitive() 5372 5373 # get current bravais lattice name 5374 d = sp.symmetry_data() 5375 bv = sp.bravais_lattice_name(d) 5376 bvp = bv.rsplit('_',1)[0]+'_P' 5377 if bv in rmg_lattices: 5378 rmg_lattice = bv 5379 #end if 5380 5381 # attempt to get a valid cell by tiling if current one is not valid 5382 if rmg_lattice is None and bvp in rmg_lattices and allow_tile: 5383 spt = Structure( 5384 units = self.units, 5385 axes = d.std_lattice, 5386 ) 5387 tmatrix,valid_by_tiling = spt.tilematrix(sp,status=True) 5388 if not valid_by_tiling: 5389 tmatrix = None 5390 else: 5391 # apply tiling matrix 5392 st = sp.copy().tile(tmatrix) 5393 # update lattice type 5394 rmg_lattice,bv = st.rmg_lattice(ret_bravais=True) 5395 #end if 5396 #end if 5397 5398 if rmg_lattice is None and (exit or warn): 5399 msg = 'Bravais lattice is not supported by the RMG code.\nCell bravais lattice: {}\nLattices supported by RMG: {}'.format(bv,list(sorted(rmg_lattices.keys()))) 5400 if exit: 5401 self.error(msg) 5402 elif warn: 5403 self.warn(msg) 5404 #end if 5405 #end if 5406 5407 if all_results: 5408 return rmg_lattice,tmatrix,s,sp,bv 5409 elif allow_tile: 5410 return rmg_lattice,tmatrix 5411 elif ret_bravais: 5412 return rmg_lattice,bv 5413 else: 5414 return rmg_lattice 5415 #end if 5416 #end def rmg_lattice 5417 5418 5419 def rmg_transform(self,allow_tile=False,allow_general=False,all_results=False): 5420 rmg_lattice,tmatrix,s,sp,bv = self.rmg_lattice( 5421 allow_tile = allow_tile, 5422 exit = not allow_general, 5423 all_results = True, 5424 ) 5425 if rmg_lattice is None and allow_general: 5426 s_trans = self.copy() 5427 rmg_inputs = obj() 5428 R = None 5429 tmatrix = None 5430 else: 5431 s_trans = self.copy() 5432 R = np.dot(np.linalg.inv(s.axes),sp.axes) 5433 s_trans.matrix_transform(R.T) 5434 if tmatrix is not None: 5435 s_trans = s_trans.tile(tmatrix) 5436 #end if 5437 if s_trans.units=='A': 5438 rmg_units = 'Angstrom' 5439 elif s_trans.units=='B': 5440 rmg_units = 'Bohr' 5441 else: 5442 self.error('Unrecognized length units in structure "{}"'.format(s_trans.units)) 5443 #end if 5444 bl_type = self.rmg_lattices[rmg_lattice] 5445 axes = s_trans.axes 5446 if bl_type=='Cubic Primitive': 5447 a = axes[0,0] 5448 b = a 5449 c = a 5450 elif bl_type=='Tetragonal Primitive': 5451 a = axes[0,0] 5452 b = a 5453 c = axes[2,2] 5454 elif bl_type=='Orthorhombic Primitive': 5455 a = axes[0,0] 5456 b = axes[1,1] 5457 c = axes[2,2] 5458 elif bl_type=='Cubic Body Centered': 5459 a = np.linalg.norm(axes[0])*2/np.sqrt(3.) 5460 b = a 5461 c = a 5462 elif bl_type=='Cubic Face Centered': 5463 a = np.linalg.norm(axes[0])*2/np.sqrt(2.) 5464 b = a 5465 c = a 5466 elif bl_type=='Hexagonal Primitive': 5467 a = np.linalg.norm(axes[0]) 5468 b = a 5469 c = np.linalg.norm(axes[2]) 5470 else: 5471 self.error('Unrecognized RMG bravais_lattice_type "{}"'.format(bl_type)) 5472 #end if 5473 rmg_inputs = obj( 5474 bravais_lattice_type = bl_type, 5475 a_length = a, 5476 b_length = b, 5477 c_length = c, 5478 length_units = rmg_units, 5479 ) 5480 #end if 5481 if not all_results: 5482 return s_trans,rmg_inputs 5483 else: 5484 return s_trans,rmg_inputs,R,tmatrix,bv 5485 #end if 5486 #end def rmg_transform 5487 5488#end class Structure 5489Structure.set_operations() 5490 5491 5492#======================# 5493# SeeK-path functions # 5494#======================# 5495 5496# installation instructions for seekpath interface 5497# 5498# installation of seekpath 5499# pip install seekpath 5500import itertools 5501from periodic_table import pt as ptable 5502try: 5503 from numpy import array_equal 5504except: 5505 array_equal = unavailable('numpy','array_equal') 5506#end try 5507try: 5508 import seekpath 5509 from seekpath import get_explicit_k_path 5510 version = seekpath.__version__ 5511 5512 try: 5513 version = [int(i) for i in version.split('.')] 5514 if len(version) < 3: 5515 raise ValueError 5516 #end if 5517 except ValueError: 5518 raise ValueError("Unable to parse version number") 5519 #end try 5520 if tuple(version) < (1, 8, 3): 5521 raise ValueError("Invalid seekpath version, need >= 1.8.4") 5522 #end if 5523 del version 5524 del seekpath 5525except: 5526 get_explicit_k_path = unavailable('seekpath','get_explicit_k_path') 5527#end try 5528 5529def _getseekpath( 5530 structure = None, 5531 with_time_reversal = False, 5532 recipe = 'hpkot', 5533 reference_distance = 0.025, 5534 threshold = 1E-7, 5535 symprec = 1E-5, 5536 angle_tolerance = 1.0, 5537 ): 5538 if not isinstance(structure, Structure): 5539 raise TypeError('structure is not of type Structure\ntype received: {0}'.format(structure.__class__.__name__)) 5540 #end if 5541 if structure.has_folded(): 5542 structure = structure.folded_structure 5543 #end if 5544 structure = structure.copy() 5545 if structure.units is not 'A': 5546 structure.change_units('A') 5547 #end if 5548 axes = structure.axes 5549 unit_pos = structure.get_scaled_positions() 5550 atomic_num = structure.get_atomic_numbers() 5551 cell = (axes,unit_pos,atomic_num) 5552 return get_explicit_k_path(cell) 5553#end def _getseekpath 5554 5555def get_conventional_cell( 5556 structure = None, 5557 symprec = 1E-5, 5558 angle_tolerance = 1.0, 5559 seekpathout = None, 5560 ): 5561 if seekpathout is None: 5562 seekpathout = _getseekpath(structure=structure, symprec = symprec, angle_tolerance=angle_tolerance) 5563 #end if 5564 axes = seekpathout['conv_lattice'] 5565 enumbers = seekpathout['conv_types'] 5566 posd = seekpathout['conv_positions'] 5567 volfac = seekpathout['volume_original_wrt_conv'] 5568 bcharge = structure.background_charge*volfac 5569 pos = dot(posd,axes) 5570 sout = structure.copy() 5571 elem = empty(len(enumbers), dtype='str') 5572 for el in ptable.elements.items(): 5573 elem[enumbers==el[1].atomic_number]=el[0] 5574 #end for 5575 if abs(bcharge-int(bcharge)) > 1E-6: 5576 raise ValueError("Invalid background charge for conventional structure") 5577 #end if 5578 return {'structure': Structure(axes=axes, elem=elem, pos=pos, background_charge = bcharge, units='A')} 5579#end def get_conventional_cell 5580 5581def get_primitive_cell( 5582 structure = None, 5583 symprec = 1E-5, 5584 angle_tolerance = 1.0, 5585 seekpathout = None, 5586 ): 5587 if seekpathout is None: 5588 seekpathout = _getseekpath(structure = structure, symprec = symprec, angle_tolerance=angle_tolerance) 5589 #end if 5590 axes = seekpathout['primitive_lattice'] 5591 enumbers = seekpathout['primitive_types'] 5592 posd = seekpathout['primitive_positions'] 5593 volfac = seekpathout['volume_original_wrt_prim'] 5594 bcharge = structure.background_charge*volfac 5595 pos = dot(posd,axes) 5596 sout = structure.copy() 5597 elem = array(enumbers, dtype='str') 5598 for el in ptable.elements.items(): 5599 elem[enumbers==el[1].atomic_number]=el[0] 5600 #end for 5601 return {'structure' : Structure(axes=axes, elem=elem, pos=pos, background_charge=bcharge, units='A'), 5602 'T' : seekpathout['primitive_transformation_matrix']} 5603#end def get_primitive_cell 5604 5605def get_kpath( 5606 structure = None, 5607 check_standard = True, 5608 with_time_reversal = False, 5609 recipe = 'hpkot', 5610 reference_distance = 0.025, 5611 threshold = 1E-7, 5612 symprec = 1E-5, 5613 angle_tolerance = 1.0, 5614 seekpathout = None, 5615 ): 5616 if seekpathout is None: 5617 seekpathout = _getseekpath(structure=structure, symprec = symprec, angle_tolerance=angle_tolerance, 5618 recipe=recipe, reference_distance=reference_distance, with_time_reversal=with_time_reversal) 5619 #end if 5620 if check_standard: 5621 structure = structure.copy() 5622 structure.change_units('A') 5623 axes = structure.axes 5624 primlat = seekpathout['primitive_lattice'] 5625 if not isclose(primlat, axes).all(): 5626 #print primlat, axes 5627 Structure.class_error('Input lattice is not the conventional lattice. If you like otherwise, set check_standard=False.') 5628 #end if 5629 #end if 5630 inverse_A_to_inverse_B = convert(1.0,'A','B') 5631 return {'explicit_kpoints_abs_inv_A' : seekpathout['explicit_kpoints_abs'], 5632 'explicit_kpoints_abs_inv_B' : seekpathout['explicit_kpoints_abs']*inverse_A_to_inverse_B, 5633 'explicit_kpoints_rel' : seekpathout['explicit_kpoints_rel'], 5634 'explicit_kpoints_labels' : seekpathout['explicit_kpoints_labels'], 5635 'path' : seekpathout['path'], 5636 'explicit_path_linearcoords': seekpathout['explicit_kpoints_linearcoord'], 5637 'point_coords' : seekpathout['point_coords']} 5638#end def get_kpath 5639 5640def get_symmetry( 5641 structure = None, 5642 symprec = 1E-5, 5643 angle_tolerance = 1.0, 5644 seekpathout = None, 5645 ): 5646 if seekpathout is None: 5647 seekpathout = _getseekpath(structure = structure, symprec = symprec, angle_tolerance=angle_tolerance) 5648 #end if 5649 sgint = seekpathout['spacegroup_international'] 5650 bravais = seekpathout['bravais_lattice'] 5651 invsym = seekpathout['has_inversion_symmetry'] 5652 sgnum = seekpathout['spacegroup_number'] 5653 5654 return {'sgint': sgint, 'bravais': bravais, 'inv_sym_exists': invsym, 'sgnum': sgnum} 5655#end def get_symmetry 5656 5657def get_structure_with_bands( 5658 cell = 0, 5659 structure = None, 5660 with_time_reversal = False, 5661 reference_distance = 0.025, 5662 threshold = 1E-7, 5663 symprec = 1E-5, 5664 angle_tolerance = 1.0, 5665 ): 5666 if cell == 0: 5667 ''' Use input structure ''' 5668 struct_band = structure.copy() 5669 elif cell == 1: 5670 ''' Use conventional structure ''' 5671 struct_band = get_conventional_cell(structure=structure, symprec=symprec, angle_tolerance=angle_tolerance)['structure'] 5672 elif cell == 2: 5673 ''' Use primitive structure ''' 5674 struct_band = get_primitive_cell(structure=structure, symprec=symprec, angle_tolerance=angle_tolerance)['structure'] 5675 else: 5676 Structure.class_error('Invalid cell type') 5677 #end if 5678 kpath = get_kpath(structure=struct_band, check_standard=False, with_time_reversal=with_time_reversal) 5679 return Structure(axes = struct_band.axes, 5680 elem = struct_band.elem, 5681 pos = struct_band.pos, 5682 background_charge = struct_band.background_charge, 5683 kpoints = kpath['explicit_kpoints_rel'], 5684 units = 'A') 5685#end def get_structure_with_bands 5686 5687# test needed 5688def get_band_tiling( 5689 structure = None, 5690 check_standard = True, 5691 use_ktol = True, 5692 kpoints_label = None, 5693 kpoints_rel = None, 5694 max_volfac = 20, 5695 min_volfac = 0, 5696 target_volfac = None, 5697 ): 5698 5699 def cube_deviation(axes): 5700 a = axes 5701 volume = abs(dot(cross(axes[0,:], axes[1,:]), axes[2,:])) 5702 dc = volume**(1./3)*sqrt(2.) 5703 d1 = abs(norm(a[0]+a[1])-dc) 5704 d2 = abs(norm(a[1]+a[2])-dc) 5705 d3 = abs(norm(a[2]+a[0])-dc) 5706 d4 = abs(norm(a[0]-a[1])-dc) 5707 d5 = abs(norm(a[1]-a[2])-dc) 5708 d6 = abs(norm(a[2]-a[0])-dc) 5709 return (d1+d2+d3+d4+d5+d6)/(6*dc) 5710 #end def cube_deviation 5711 5712 def cuboid_with_int_edges(vol): 5713 # Given a volume, return the cuboids which have integer edges 5714 divisors = [] 5715 edges = [] 5716 if isinstance(vol, int): 5717 i = 1 5718 while i<=vol: 5719 if vol%i==0: 5720 divisors.append(i) 5721 #end if 5722 i+=1 5723 #end while 5724 for i in divisors: 5725 for j in divisors: 5726 for k in divisors: 5727 if i*j*k == vol: 5728 edges.append([i,j,k]) 5729 #end if 5730 #end for 5731 #end for 5732 #end for 5733 else: 5734 self.error('Volume multiplier must be integer') 5735 #end if 5736 5737 return edges 5738 #end def cuboid_with_int_edges 5739 5740 def alphas_on_grid(alphas, divs): 5741 new_alphas = [] 5742 for alpha in alphas: 5743 abs_alpha = abs(alpha) 5744 sign_alpha = sign(alpha) 5745 new_alpha = round(abs_alpha*divs)*1./divs*sign_alpha 5746 new_alphas.append(new_alpha) 5747 #end for 5748 return new_alphas 5749 #end def alphas_on_grid 5750 5751 def find_alphas(structure, kpoints_label, kpoints_rel, check_standard): 5752 # Read wavevectors from the input and return the differences between all wavevectors (alphas) 5753 kpath = get_kpath(structure = structure, check_standard = check_standard) 5754 kpath_label = array(kpath['explicit_kpoints_labels']) 5755 kpath_rel = kpath['explicit_kpoints_rel'] 5756 kpts = dict() 5757 5758 if kpoints_label is None: 5759 kpoints_label = [] 5760 if kpoints_rel is None: 5761 Structure.class_error('Please define symbolic or crystal coordinates for kpoints. e.g. [\'GAMMA\', \'K\'] or [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]') 5762 else: 5763 for k in kpoints_rel: 5764 kindex = isclose(kpath_rel,k, atol=1e-5).all(1) 5765 if any(kindex): 5766 kpts[kpath_label[kindex][0]] = array(k) 5767 else: 5768 Structure.class_error('{0} is not found in the kpath'.format(k)) 5769 #end if 5770 #end for 5771 #end if 5772 else: 5773 if kpoints_rel is not None: 5774 Structure.class_error('Both symbolic and crystal k-points are defined.') 5775 else: 5776 kpoints_rel = [] 5777 num_kpoints = 0 5778 for k in kpoints_label: 5779 kindex = k == kpath_label 5780 if any(kindex): 5781 if k == '' or k == None: 5782 k = '{0}'.format(num_kpoints) 5783 #end if 5784 kpts[k] = array(kpath_rel[kindex][0]) 5785 else: 5786 Structure.class_error('{0} is not found in the kpath'.format(k)) 5787 #end if 5788 #end for 5789 #end if 5790 #end if 5791 alphas = array([x[0] - x[1] for x in itertools.combinations(kpts.values(),2)]) #Combinations of k_1 - k_2 in kpts list 5792 kpt0 = list(kpts.values())[0] 5793 return alphas, kpt0 5794 #end def find_alphas 5795 5796 def find_vars(alphas,min_volfac, max_volfac, target_volfac, use_ktol): 5797 ''' 5798 Find variables to generate possible smallest matrices in the Upper Triangular Hermite Normal Form, from PHYSICAL REVIEW B 92, 184301 (2015) 5799 For target or min/max volumes, it returns vol_mul, which is the smallest volume multiplier to be used on the volfac here 5800 ''' 5801 if use_ktol: 5802 ktol = 0.25/max_volfac 5803 else: 5804 ktol = 0.0 5805 #end if 5806 5807 if target_volfac is not None: 5808 if min_volfac is None and max_volfac is None: 5809 min_volfac = target_volfac 5810 max_volfac = target_volfac 5811 else: 5812 print("target_volfac and {min_volfac, max_volfac} cannot be defined together!") 5813 exit() 5814 #end if 5815 #end if 5816 5817 cur_volfac = 1e6 5818 vars = [] 5819 for mvol in range(max_volfac, 0, -1): 5820 cuboids = cuboid_with_int_edges(mvol) 5821 for c in cuboids: 5822 a_1, a_2, a_3 = c 5823 new_alphas = alphas_on_grid(alphas, c) 5824 5825 5826 rec_grid_voxel = array([1./a_1,1./a_2,1./a_3]) # reciprocal grid voxel size 5827 rem = [] 5828 for alpha in alphas: 5829 rem.append(mod(abs(alpha), rec_grid_voxel)) 5830 #end for 5831 if all(isclose(rem,0.0,atol=ktol)): 5832 n1 = a_1 5833 n2 = a_2 5834 n3 = a_3 5835 g12 = np.gcd.reduce([n1,n2]) 5836 g13 = np.gcd.reduce([n1,n3]) 5837 g23 = np.gcd.reduce([n2,n3]) 5838 g123 = np.gcd.reduce([n1, n2, n3]) 5839 volfac = n1*n2*n3*g123//(g12*g13*g23) 5840 if volfac < cur_volfac: #min_volfac <= volfac and and volfac <= max_volfac: 5841 vars = [[n1, n2, n3, g12, g13, g23, g123]] 5842 cur_volfac = volfac 5843 elif volfac == cur_volfac: 5844 vars.append([n1, n2, n3, g12, g13, g23, g123]) 5845 #end if 5846 #end if 5847 #end for 5848 #end for 5849 if vars == []: 5850 print('Change ktol') 5851 exit() 5852 else: 5853 can_be_found = False 5854 vol_mul = 1 5855 while not can_be_found: 5856 if volfac*vol_mul <= max_volfac and volfac*vol_mul >= min_volfac: 5857 can_be_found = True 5858 elif volfac*vol_mul > max_volfac: 5859 print('Increase max_volfac or target_volfac!') 5860 exit() 5861 else: 5862 vol_mul+=1 5863 #end if 5864 #end while 5865 return vars, vol_mul 5866 #end if 5867 5868 #end def find_vars 5869 5870 def find_mats(mat_vars, alphas): 5871 ''' 5872 Given the variables (v), return the list of all upper triangular matrices as in PHYSICAL REVIEW B 92, 184301 (2015) 5873 ''' 5874 mats = [] 5875 for v in mat_vars: 5876 n1, n2, n3, g12, g13, g23, g123 = v 5877 #New alphas exactly on the voxels thanks to ktol 5878 for p in range(0, g23): 5879 for q in range(0, g12//g123): 5880 for r in range(g13*g23//g123): 5881 temp_mat = [ 5882 [g123*n1//(g12*g13), q*g123*n2//(g12*g23), r*g123*n3//(g13*g23)], 5883 [0, n2//g23, p*n3//g23], 5884 [0,0,n3]] 5885 comm = [] 5886 div = array([n1, n2, n3]) 5887 new_alphas = alphas_on_grid(alphas, div) 5888 for new_alpha in new_alphas: 5889 if (isclose(abs(dot(temp_mat,new_alpha))%1.0, 0, atol = 1e-6)).all(): 5890 comm.append(True) # new_alpha is commensurate with tmat 5891 else: 5892 comm.append(False) 5893 #end if 5894 #end for 5895 if all(comm) and temp_mat not in mats: # if all new_alphas are commensurate with tmat 5896 mats.append(temp_mat) 5897 #end if 5898 #end for 5899 #end for 5900 #end for 5901 #end for 5902 5903 return mats 5904 5905 #end def find_mats 5906 5907 def find_cubic_mat(mats, structure, mat_vol_mul): 5908 final_axes = [] 5909 final_t = [] 5910 final_cubicity = 1e6 5911 mats = array(mats) 5912 for m in mats: 5913 axes = structure.axes.copy() 5914 s = structure.copy() 5915 [m_t, r] = optimal_tilematrix(s.tile(m), volfac=mat_vol_mul) 5916 m_axes = dot(dot(m_t, m), axes) 5917 m_cubicity = cube_deviation(m_axes) 5918 if m_cubicity < final_cubicity: 5919 final_axes = m_axes 5920 final_cubicity = m_cubicity 5921 final_t = dot(m_t, m) 5922 #end if 5923 #end for 5924 return final_t.tolist() 5925 #end def find_cubic_mat 5926 5927 def find_shift(final_mat, structure, kpt0): 5928 return None 5929 #end def find_cubic_mat 5930 5931 alphas, kpt0 = find_alphas(structure,kpoints_label,kpoints_rel, check_standard) # Wavevector differences 5932 mat_vars, mat_vol_mul = find_vars(alphas,min_volfac,max_volfac,target_volfac,use_ktol) # Variables to construct upper triangular matrices 5933 mats = find_mats(mat_vars,alphas) # List of upper triangular matrices that are commensurate with alphas 5934 final_mat = find_cubic_mat(mats, structure, mat_vol_mul) # Matrix leading to a lattice with highest cubicity, optimized using elementary operations 5935 shift = find_shift(final_mat, structure, kpts0) # Find the grid shift 5936 o = obj() 5937 o.mat = final_mat 5938 o.shift = shift 5939 o.det = det(final_mat) 5940 return o 5941#end def get_band_tiling 5942 5943def get_seekpath_full( 5944 structure = None, 5945 seekpathout = None, 5946 conventional = False, 5947 primitive = False, 5948 **kwargs 5949 ): 5950 if seekpathout is None: 5951 seekpathout = _getseekpath(structure,**kwargs) 5952 #end if 5953 res = obj(seekpathout) 5954 for k,v in res.items(): 5955 if isinstance(v,dict): 5956 res[k] = obj(v) 5957 #end if 5958 #end for 5959 if conventional: 5960 conv = get_conventional_cell(structure,seekpathout=seekpathout) 5961 res.conventional = conv['structure'] 5962 #end if 5963 if primitive: 5964 prim = get_primitive_cell(structure,seekpathout=seekpathout) 5965 res.primitive = prim['structure'] 5966 res.prim_tmatrix = prim['T'] 5967 #end if 5968 return res 5969#end def get_seekpath_full 5970 5971skp = obj( 5972 _getseekpath = _getseekpath, 5973 get_conventional_cell = get_conventional_cell, 5974 get_primitive_cell = get_primitive_cell, 5975 get_kpath = get_kpath, 5976 get_symmetry = get_symmetry, 5977 get_structure_with_bands = get_structure_with_bands, 5978 get_band_tiling = get_band_tiling, 5979 ) 5980 5981#==========================# 5982# end SeeK-path functions # 5983#==========================# 5984 5985 5986 5987def interpolate_structures(struct1,struct2=None,images=None,min_image=True,recenter=True,match_com=False,repackage=False,chained=False): 5988 if images is None: 5989 Structure.class_error('images must be provided','interpolate_structures') 5990 #end if 5991 5992 # if a list of structures is provided, 5993 # interpolate between pairs in the chain of structures 5994 if isinstance(struct1,(list,tuple)): 5995 structures_in = struct1 5996 structures = [] 5997 for n in range(len(structures_in)-1): 5998 struct1 = structures_in[n] 5999 struct2 = structures_in[n+1] 6000 structs = interpolate_structures(struct1,struct2,images,min_image,recenter,match_com,repackage,chained=True) 6001 if n==0: 6002 structures.append(structs[0]) 6003 #end if 6004 structures.extend(structs[1:-1]) 6005 if n==len(structures_in)-2: 6006 structures.append(structs[-1]) 6007 #end if 6008 #end for 6009 return structures 6010 #end if 6011 6012 # handle PhysicalSystem objects indirectly 6013 system1 = None 6014 system2 = None 6015 if not isinstance(struct1,Structure): 6016 system1 = struct1.copy() 6017 system1.remove_folded() 6018 struct1 = system1.structure 6019 #end if 6020 if not isinstance(struct2,Structure): 6021 system2 = struct2.copy() 6022 system2.remove_folded() 6023 struct2 = system2.structure 6024 #end if 6025 6026 # perform the interpolation 6027 structures = struct1.interpolate(struct2,images,min_image,recenter,match_com) 6028 6029 # repackage into physical system objects if requested 6030 if repackage: 6031 if system1!=None: 6032 system = system1 6033 elif system2!=None: 6034 system = system2 6035 else: 6036 Structure.class_error('cannot repackage into physical systems since no system object was provided in place of a structure','interpolate_structures') 6037 #end if 6038 systems = [] 6039 for s in structures: 6040 ps = system.copy() 6041 ps.structure = s 6042 systems.append(ps) 6043 #end for 6044 result = systems 6045 else: 6046 result = structures 6047 #end if 6048 6049 return result 6050#end def interpolate_structures 6051 6052 6053# test needed 6054def structure_animation(filepath,structures,tiling=None): 6055 path,file = os.path.split(filepath) 6056 if not file.endswith('xyz'): 6057 Structure.class_error('only xyz files are supported for now','structure_animation') 6058 #end if 6059 anim = '' 6060 for s in structures: 6061 if tiling is None: 6062 anim += s.write_xyz() 6063 else: 6064 anim += s.tile(tiling).write_xyz() 6065 #end if 6066 #end for 6067 open(filepath,'w').write(anim) 6068#end def structure_animation 6069 6070 6071 6072class DefectStructure(Structure): 6073 def __init__(self,*args,**kwargs): 6074 if len(args)>0 and isinstance(args[0],Structure): 6075 self.transfer_from(args[0],copy=True) 6076 else: 6077 Structure.__init__(self,*args,**kwargs) 6078 #end if 6079 #end def __init__ 6080 6081 6082 def defect_from_bond_compression(self,compression_cutoff,bond_eq,neighbors): 6083 bind,bcent,blens = self.bonds(neighbors) 6084 ind = bind[ abs(blens/bond_eq - 1.) > compression_cutoff ] 6085 idefect = array(list(set(ind.ravel()))) 6086 defect = self.carve(idefect) 6087 return defect 6088 #end def defect_from_bond_compression 6089 6090 6091 def defect_from_displacement(self,displacement_cutoff,reference): 6092 displacement = self.scalar_displacement(reference) 6093 idefect = displacement > displacement_cutoff 6094 defect = self.carve(idefect) 6095 return defect 6096 #end def defect_from_displacement 6097 6098 6099 def compare(self,dist_cutoff,d1,d2=None): 6100 if d2==None: 6101 d2 = d1 6102 d1 = self 6103 #end if 6104 res = Sobj() 6105 natoms1 = len(d1.pos) 6106 natoms2 = len(d2.pos) 6107 if natoms1<natoms2: 6108 dsmall,dlarge = d1,d2 6109 else: 6110 dsmall,dlarge = d2,d1 6111 #end if 6112 nn = nearest_neighbors(1,dlarge,dsmall) 6113 dist = dsmall.distances(dlarge[nn.ravel()]) 6114 dmatch = dist<dist_cutoff 6115 ismall = array(list(range(len(dsmall.pos)))) 6116 ismall = ismall[dmatch] 6117 ilarge = nn[ismall] 6118 if natoms1<natoms2: 6119 i1,i2 = ismall,ilarge 6120 else: 6121 i2,i1 = ismall,ilarge 6122 #end if 6123 natoms_match = dmatch.sum() 6124 res.all_match = natoms1==natoms2 and natoms1==natoms_match 6125 res.natoms_match = natoms_match 6126 res.imatch1 = i1 6127 res.imatch2 = i2 6128 return res 6129 #end def compare 6130#end class DefectStructure 6131 6132 6133 6134class Crystal(Structure): 6135 lattice_constants = obj( 6136 triclinic = ['a','b','c','alpha','beta','gamma'], 6137 monoclinic = ['a','b','c','beta'], 6138 orthorhombic = ['a','b','c'], 6139 tetragonal = ['a','c'], 6140 hexagonal = ['a','c'], 6141 cubic = ['a'], 6142 rhombohedral = ['a','alpha'] 6143 ) 6144 6145 lattices = list(lattice_constants.keys()) 6146 6147 centering_types = obj( 6148 primitive = 'P', 6149 base_centered = ('A','B','C'), 6150 face_centered = 'F', 6151 body_centered = 'I', 6152 rhombohedral_centered = 'R' 6153 ) 6154 6155 lattice_centerings = obj( 6156 triclinic = ['P'], 6157 monoclinic = ['P','A','B','C'], 6158 orthorhombic = ['P','C','I','F'], 6159 tetragonal = ['P','I'], 6160 hexagonal = ['P','R'], 6161 cubic = ['P','I','F'], 6162 rhombohedral = ['P'] 6163 ) 6164 6165 centerings = obj( 6166 P = [], 6167 A = [[0,.5,.5]], 6168 B = [[.5,0,.5]], 6169 C = [[.5,.5,0]], 6170 F = [[0,.5,.5],[.5,0,.5],[.5,.5,0]], 6171 I = [[.5,.5,.5]], 6172 R = [[2./3, 1./3, 1./3],[1./3, 2./3, 2./3]] 6173 ) 6174 6175 cell_types = set(['primitive','conventional']) 6176 6177 cell_aliases = obj( 6178 prim = 'primitive', 6179 conv = 'conventional' 6180 ) 6181 cell_classes = obj( 6182 sc = 'cubic', 6183 bcc = 'cubic', 6184 fcc = 'cubic', 6185 hex = 'hexagonal' 6186 ) 6187 for lattice in lattices: 6188 cell_classes[lattice]=lattice 6189 #end for 6190 6191 6192 #helpful websites for structures 6193 # wikipedia.org 6194 # webelements.com 6195 # webmineral.com 6196 # springermaterials.com 6197 6198 6199 known_crystals = { 6200 ('diamond','fcc'):obj( 6201 lattice = 'cubic', 6202 cell = 'primitive', 6203 centering = 'F', 6204 constants = 3.57, 6205 units = 'A', 6206 atoms = 'C', 6207 basis = [[0,0,0],[.25,.25,.25]] 6208 ), 6209 ('diamond','sc'):obj( 6210 lattice = 'cubic', 6211 cell = 'conventional', 6212 centering = 'F', 6213 constants = 3.57, 6214 units = 'A', 6215 atoms = 'C', 6216 basis = [[0,0,0],[.25,.25,.25]] 6217 ), 6218 ('diamond','prim'):obj( 6219 lattice = 'cubic', 6220 cell = 'primitive', 6221 centering = 'F', 6222 constants = 3.57, 6223 units = 'A', 6224 atoms = 'C', 6225 basis = [[0,0,0],[.25,.25,.25]] 6226 ), 6227 ('diamond','conv'):obj( 6228 lattice = 'cubic', 6229 cell = 'conventional', 6230 centering = 'F', 6231 constants = 3.57, 6232 units = 'A', 6233 atoms = 'C', 6234 basis = [[0,0,0],[.25,.25,.25]] 6235 ), 6236 ('wurtzite','prim'):obj( 6237 lattice = 'hexagonal', 6238 cell = 'primitive', 6239 centering = 'P', 6240 constants = (3.35,5.22), 6241 units = 'A', 6242 #atoms = ('Zn','O'), 6243 #basis = [[1./3, 2./3, 3./8],[1./3, 2./3, 0]] 6244 atoms = ('Zn','O','Zn','O'), 6245 basis = [[0,0,5./8],[0,0,0],[2./3,1./3,1./8],[2./3,1./3,1./2]] 6246 ), 6247 ('ZnO','prim'):obj( 6248 lattice = 'wurtzite', 6249 cell = 'prim', 6250 constants = (3.35,5.22), 6251 units = 'A', 6252 atoms = ('Zn','O','Zn','O') 6253 ), 6254 ('NaCl','prim'):obj( 6255 lattice = 'cubic', 6256 cell = 'primitive', 6257 centering = 'F', 6258 constants = 5.64, 6259 units = 'A', 6260 atoms = ('Na','Cl'), 6261 basis = [[0,0,0],[.5,0,0]], 6262 basis_vectors = 'conventional' 6263 ), 6264 ('rocksalt','prim'):obj( 6265 lattice = 'cubic', 6266 cell = 'primitive', 6267 centering = 'F', 6268 constants = 5.64, 6269 units = 'A', 6270 atoms = ('Na','Cl'), 6271 basis = [[0,0,0],[.5,0,0]], 6272 basis_vectors = 'conventional' 6273 ), 6274 ('copper','prim'):obj( 6275 lattice = 'cubic', 6276 cell = 'primitive', 6277 centering = 'F', 6278 constants = 3.615, 6279 units = 'A', 6280 atoms = 'Cu' 6281 ), 6282 ('calcium','prim'):obj( 6283 lattice = 'cubic', 6284 cell = 'primitive', 6285 centering = 'F', 6286 constants = 5.588, 6287 units = 'A', 6288 atoms = 'Ca' 6289 ), 6290 # http://www.webelements.com/oxygen/crystal_structure.html 6291 # Phys Rev 160 694 6292 ('oxygen','prim'):obj( 6293 lattice = 'monoclinic', 6294 cell = 'primitive', 6295 centering = 'C', 6296 constants = (5.403,3.429,5.086,132.53), 6297 units = 'A', 6298 angular_units = 'degrees', 6299 atoms = ('O','O'), 6300 basis = [[0,0,1.15/2],[0,0,-1.15/2]], 6301 basis_vectors = identity(3) 6302 ), 6303 # http://en.wikipedia.org/wiki/Calcium_oxide 6304 # http://www.springermaterials.com/docs/info/10681719_224.html 6305 ('CaO','prim'):obj( 6306 lattice = 'NaCl', 6307 cell = 'prim', 6308 constants = 4.81, 6309 atoms = ('Ca','O') 6310 ), 6311 ('CaO','conv'):obj( 6312 lattice = 'NaCl', 6313 cell = 'conv', 6314 constants = 4.81, 6315 atoms = ('Ca','O') 6316 ), 6317 # http://en.wikipedia.org/wiki/Copper%28II%29_oxide 6318 # http://iopscience.iop.org/0953-8984/3/28/001/ 6319 # http://www.webelements.com/compounds/copper/copper_oxide.html 6320 # http://www.webmineral.com/data/Tenorite.shtml 6321 ('CuO','prim'):obj( 6322 lattice = 'monoclinic', 6323 cell = 'primitive', 6324 centering = 'C', 6325 constants = (4.683,3.422,5.128,99.54), 6326 units = 'A', 6327 angular_units = 'degrees', 6328 atoms = ('Cu','O','Cu','O'), 6329 basis = [[.25,.25,0],[0,.418,.25], 6330 [.25,.75,.5],[.5,.5-.418,.75]], 6331 basis_vectors = 'conventional' 6332 ), 6333 ('Ca2CuO3','prim'):obj(# kateryna foyevtsova 6334 lattice = 'orthorhombic', 6335 cell = 'primitive', 6336 centering = 'I', 6337 constants = (3.77,3.25,12.23), 6338 units = 'A', 6339 atoms = ('Cu','O','O','O','Ca','Ca'), 6340 basis = [[ 0, 0, 0 ], 6341 [ .50, 0, 0 ], 6342 [ 0, 0, .16026165], 6343 [ 0, 0, .83973835], 6344 [ 0, 0, .35077678], 6345 [ 0, 0, .64922322]], 6346 basis_vectors = 'conventional' 6347 ), 6348 ('La2CuO4','prim'):obj( #tetragonal structure 6349 lattice = 'tetragonal', 6350 cell = 'primitive', 6351 centering = 'I', 6352 constants = (3.809,13.169), 6353 units = 'A', 6354 atoms = ('Cu','O','O','O','O','La','La'), 6355 basis = [[ 0, 0, 0], 6356 [ .5, 0, 0], 6357 [ 0, .5, 0], 6358 [ 0, 0, .182], 6359 [ 0, 0, -.182], 6360 [ 0, 0, .362], 6361 [ 0, 0, -.362]] 6362 ), 6363 ('Cl2Ca2CuO2','prim'):obj( 6364 lattice = 'tetragonal', 6365 cell = 'primitive', 6366 centering = 'I', 6367 constants = (3.869,15.05), 6368 units = 'A', 6369 atoms = ('Cu','O','O','Ca','Ca','Cl','Cl'), 6370 basis = [[ 0, 0, 0 ], 6371 [ .5, 0, 0 ], 6372 [ 0, .5, 0 ], 6373 [ .5, .5, .104], 6374 [ 0, 0, .396], 6375 [ 0, 0, .183], 6376 [ .5, .5, .317]], 6377 basis_vectors = 'conventional' 6378 ), 6379 ('Cl2Ca2CuO2','afm'):obj( 6380 lattice = 'tetragonal', 6381 cell = 'conventional', 6382 centering = 'P', 6383 axes = [[.5,-.5,0],[.5,.5,0],[0,0,1]], 6384 constants = (2*3.869,15.05), 6385 units = 'A', 6386 atoms = 4*['Cu','O','O','Ca','Ca','Cl','Cl'], 6387 basis = [[ 0, 0, 0 ], #Cu 6388 [ .25, 0, 0 ], 6389 [ 0, .25, 0 ], 6390 [ .25, .25, .104], 6391 [ 0, 0, .396], 6392 [ 0, 0, .183], 6393 [ .25, .25, .317], 6394 [ .25, .25, .5 ], #Cu 6395 [ .5, .25, .5 ], 6396 [ .25, .5, .5 ], 6397 [ .5, .5, .604], 6398 [ .25, .25, .896], 6399 [ .25, .25, .683], 6400 [ .5, .5, .817], 6401 [ .5, 0, 0 ], #Cu2 6402 [ .75, 0, 0 ], 6403 [ .5, .25, 0 ], 6404 [ .75, .25, .104], 6405 [ .5, 0, .396], 6406 [ .5, 0, .183], 6407 [ .75, .25, .317], 6408 [ .75, .25, .5 ], #Cu2 6409 [ 0, .25, .5 ], 6410 [ .75, .5, .5 ], 6411 [ 0, .5, .604], 6412 [ .75, .25, .896], 6413 [ .75, .25, .683], 6414 [ 0, .5, .817]], 6415 basis_vectors = 'conventional' 6416 ), 6417 ('CuO2_plane','prim'):obj( 6418 lattice = 'tetragonal', 6419 cell = 'primitive', 6420 centering = 'P', 6421 constants = (3.809,13.169), 6422 units = 'A', 6423 atoms = ('Cu','O','O'), 6424 basis = [[ 0, 0, 0], 6425 [ .5, 0, 0], 6426 [ 0, .5, 0]] 6427 ), 6428 ('graphite_aa','hex'):obj( 6429 axes = [[1./2,-sqrt(3.)/2,0],[1./2,sqrt(3.)/2,0],[0,0,1]], 6430 constants = (2.462,3.525), 6431 units = 'A', 6432 atoms = ('C','C'), 6433 basis = [[0,0,0],[2./3,1./3,0]] 6434 ), 6435 ('graphite_ab','hex'):obj( 6436 axes = [[1./2,-sqrt(3.)/2,0],[1./2,sqrt(3.)/2,0],[0,0,1]], 6437 constants = (2.462,3.525), 6438 units = 'A', 6439 cscale = (1,2), 6440 atoms = ('C','C','C','C'), 6441 basis = [[0,0,0],[2./3,1./3,0],[0,0,1./2],[1./3,2./3,1./2]] 6442 ), 6443 ('graphene','prim'):obj( 6444 lattice = 'hexagonal', 6445 cell = 'primitive', 6446 centering = 'P', 6447 constants = (2.462,15.0), 6448 units = 'A', 6449 atoms = ('C','C'), 6450 basis = [[0,0,0],[2./3,1./3,0]] 6451 ), 6452 ('graphene','rect'):obj( 6453 lattice = 'orthorhombic', 6454 cell = 'conventional', 6455 centering = 'C', 6456 constants = (2.462,sqrt(3.)*2.462,15.0), 6457 units = 'A', 6458 atoms = ('C','C'), 6459 basis = [[0,0,0],[1./2,1./6,0]] 6460 ) 6461 } 6462 6463 kc_keys = list(known_crystals.keys()) 6464 for (name,cell) in kc_keys: 6465 desc = known_crystals[name,cell] 6466 if cell=='prim' and not (name,'conv') in known_crystals: 6467 cdesc = desc.copy() 6468 if cdesc.cell=='primitive': 6469 cdesc.cell = 'conventional' 6470 known_crystals[name,'conv'] = cdesc 6471 elif cdesc.cell=='prim': 6472 cdesc.cell = 'conv' 6473 known_crystals[name,'conv'] = cdesc 6474 #end if 6475 #end if 6476 #end if 6477 del kc_keys 6478 6479 6480 def __init__(self, 6481 lattice = None, 6482 cell = None, 6483 centering = None, 6484 constants = None, 6485 atoms = None, 6486 basis = None, 6487 basis_vectors = None, 6488 tiling = None, 6489 cscale = None, 6490 axes = None, 6491 units = None, 6492 angular_units = 'degrees', 6493 kpoints = None, 6494 kgrid = None, 6495 mag = None, 6496 frozen = None, 6497 magnetization = None, 6498 kshift = (0,0,0), 6499 permute = None, 6500 operations = None, 6501 elem = None, 6502 pos = None, 6503 use_prim = None, 6504 add_kpath = False, 6505 symm_kgrid = False, 6506 ): 6507 6508 if lattice is None and cell is None and atoms is None and units is None: 6509 return 6510 #end if 6511 6512 gi = obj( 6513 lattice = lattice , 6514 cell = cell , 6515 centering = centering , 6516 constants = constants , 6517 atoms = atoms , 6518 basis = basis , 6519 basis_vectors = basis_vectors , 6520 tiling = tiling , 6521 cscale = cscale , 6522 axes = axes , 6523 units = units , 6524 angular_units = angular_units , 6525 frozen = frozen , 6526 mag = mag , 6527 magnetization = magnetization , 6528 kpoints = kpoints , 6529 kgrid = kgrid , 6530 kshift = kshift , 6531 permute = permute , 6532 operations = operations , 6533 elem = elem , 6534 pos = pos , 6535 use_prim = use_prim , 6536 add_kpath = add_kpath , 6537 symm_kgrid = symm_kgrid , 6538 ) 6539 generation_info = gi.copy() 6540 6541 lattice_in = lattice 6542 if isinstance(lattice,str): 6543 lattice=lattice.lower() 6544 #end if 6545 if isinstance(cell,str): 6546 cell=cell.lower() 6547 #end if 6548 6549 known_crystal = False 6550 if (lattice_in,cell) in self.known_crystals: 6551 known_crystal = True 6552 lattice_info = self.known_crystals[lattice_in,cell].copy() 6553 elif (lattice,cell) in self.known_crystals: 6554 known_crystal = True 6555 lattice_info = self.known_crystals[lattice,cell].copy() 6556 #end if 6557 6558 if known_crystal: 6559 while 'lattice' in lattice_info and 'cell' in lattice_info and (lattice_info.lattice,lattice_info.cell) in self.known_crystals: 6560 li_old = lattice_info 6561 lattice_info = self.known_crystals[li_old.lattice,li_old.cell].copy() 6562 del li_old.lattice 6563 del li_old.cell 6564 lattice_info.transfer_from(li_old,copy=False) 6565 #end while 6566 if 'cell' in lattice_info: 6567 cell = lattice_info.cell 6568 elif cell in self.cell_aliases: 6569 cell = self.cell_aliases[cell] 6570 elif cell in self.cell_classes: 6571 lattice = self.cell_classes[cell] 6572 else: 6573 self.error('cell shape '+cell+' is not recognized\n the variable cell_classes or cell_aliases must be updated to include '+cell) 6574 #end if 6575 if 'lattice' in lattice_info: 6576 lattice = lattice_info.lattice 6577 #end if 6578 if 'angular_units' in lattice_info: 6579 angular_units = lattice_info.angular_units 6580 #end if 6581 inputs = obj( 6582 centering = centering, 6583 constants = constants, 6584 atoms = atoms, 6585 basis = basis, 6586 basis_vectors = basis_vectors, 6587 tiling = tiling, 6588 cscale = cscale, 6589 axes = axes, 6590 units = units 6591 ) 6592 for var,val in inputs.items(): 6593 if val is None and var in lattice_info: 6594 inputs[var] = lattice_info[var] 6595 #end if 6596 #end for 6597 centering,constants,atoms,basis,basis_vectors,tiling,cscale,axes,units=inputs.list('centering','constants','atoms','basis','basis_vectors','tiling','cscale','axes','units') 6598 #end if 6599 6600 if constants is None: 6601 self.error('the variable constants must be provided') 6602 #end if 6603 if atoms is None: 6604 self.error('the variable atoms must be provided') 6605 #end if 6606 6607 if lattice not in self.lattices: 6608 self.error('lattice type '+str(lattice)+' is not recognized\n valid lattice types are: '+str(list(self.lattices))) 6609 #end if 6610 if cell=='conventional': 6611 if centering is None: 6612 self.error('centering must be provided for a conventional cell\n options for a '+lattice+' lattice are: '+str(self.lattice_centerings[lattice])) 6613 elif centering not in self.centerings: 6614 self.error('centering type '+str(centering)+' is not recognized\n options for a '+lattice+' lattice are: '+str(self.lattice_centerings[lattice])) 6615 #end if 6616 #end if 6617 if isinstance(constants,int) or isinstance(constants,float): 6618 constants=[constants] 6619 #end if 6620 if len(constants)!=len(self.lattice_constants[lattice]): 6621 self.error('the '+lattice+' lattice depends on the constants '+str(self.lattice_constants[lattice])+'\n you provided '+str(len(constants))+': '+str(constants)) 6622 #end if 6623 if isinstance(atoms,str): 6624 if basis!=None: 6625 atoms = len(basis)*[atoms] 6626 else: 6627 atoms=[atoms] 6628 #end if 6629 #end if 6630 if basis is None: 6631 if len(atoms)==1: 6632 basis = [(0,0,0)] 6633 else: 6634 self.error('must provide as many basis coordinates as basis atoms\n atoms provided: '+str(atoms)+'\n basis provided: '+str(basis)) 6635 #end if 6636 #end if 6637 if basis_vectors is not None and not isinstance(basis_vectors,str) and len(basis_vectors)!=3: 6638 self.error('3 basis vectors must be given, you provided '+str(len(basis))+':\n '+str(basis_vectors)) 6639 #end if 6640 6641 if tiling is None: 6642 tiling = (1,1,1) 6643 #end if 6644 if cscale is None: 6645 cscale = len(constants)*[1] 6646 #end if 6647 if len(cscale)!=len(constants): 6648 self.error('cscale and constants must be the same length') 6649 #end if 6650 basis = array(basis) 6651 tiling = array(tiling,dtype=int) 6652 cscale = array(cscale) 6653 constants = cscale*array(constants) 6654 6655 a,b,c,alpha,beta,gamma = None,None,None,None,None,None 6656 if angular_units=='radians': 6657 pi_1o2 = pi/2 6658 pi_2o3 = 2*pi/3 6659 elif angular_units=='degrees': 6660 pi_1o2 = 90. 6661 pi_2o3 = 120. 6662 else: 6663 self.error('angular units must be radians or degrees\n you provided '+str(angular_units)) 6664 #end if 6665 if lattice=='triclinic': 6666 a,b,c,alpha,beta,gamma = constants 6667 elif lattice=='monoclinic': 6668 a,b,c,beta = constants 6669 alpha = gamma = pi_1o2 6670 elif lattice=='orthorhombic': 6671 a,b,c = constants 6672 alpha=beta=gamma=pi_1o2 6673 elif lattice=='tetragonal': 6674 a,c = constants 6675 b=a 6676 alpha=beta=gamma=pi_1o2 6677 elif lattice=='hexagonal': 6678 a,c = constants 6679 b=a 6680 alpha=beta=pi_1o2 6681 gamma=pi_2o3 6682 elif lattice=='cubic': 6683 a=constants[0] 6684 b=c=a 6685 alpha=beta=gamma=pi_1o2 6686 elif lattice=='rhombohedral': 6687 a,alpha = constants 6688 b=c=a 6689 beta=gamma=alpha 6690 #end if 6691 if angular_units=='degrees': 6692 alpha *= pi/180 6693 beta *= pi/180 6694 gamma *= pi/180 6695 #end if 6696 6697 points = [[0,0,0]] 6698 #get the conventional axes 6699 sa,ca = sin(alpha),cos(alpha) 6700 sb,cb = sin(beta) ,cos(beta) 6701 sg,cg = sin(gamma),cos(gamma) 6702 y = (ca-cg*cb)/sg 6703 a1c = a*array([1,0,0]) 6704 a2c = b*array([cg,sg,0]) 6705 a3c = c*array([cb,y,sqrt(sb**2-y**2)]) 6706 #a1c = array([a,0,0]) 6707 #a2c = array([b*cos(gamma),b*sin(gamma),0]) 6708 #a3c = array([c*cos(beta),c*cos(alpha)*sin(beta),c*sin(alpha)*sin(beta)]) 6709 axes_conv = array([a1c,a2c,a3c]).copy() 6710 6711 if axes is None: 6712 if cell not in self.cell_types: 6713 self.error('cell must be primitive or conventional\n You provided: '+str(cell)) 6714 #end if 6715 if cell=='primitive' and centering=='P': 6716 cell='conventional' 6717 #end if 6718 #get the primitive axes 6719 if centering=='P': 6720 a1 = a1c 6721 a2 = a2c 6722 a3 = a3c 6723 elif centering=='A': 6724 a1 = a1c 6725 a2 = (a2c+a3c)/2 6726 a3 = (-a2c+a3c)/2 6727 elif centering=='B': 6728 a1 = (a1c+a3c)/2 6729 a2 = a2c 6730 a3 = (-a1c+a3c)/2 6731 elif centering=='C': 6732 a1 = (a1c-a2c)/2 6733 a2 = (a1c+a2c)/2 6734 a3 = a3c 6735 elif centering=='I': 6736 a1=[ a/2, b/2,-c/2] 6737 a2=[-a/2, b/2, c/2] 6738 a3=[ a/2,-b/2, c/2] 6739 elif centering=='F': 6740 a1=[a/2, b/2, 0] 6741 a2=[ 0, b/2, c/2] 6742 a3=[a/2, 0, c/2] 6743 elif centering=='R': 6744 a1=[ a, 0, 0] 6745 a2=[ a/2, a*sqrt(3.)/2, 0] 6746 a3=[-a/6, a/(2*sqrt(3.)), c/3] 6747 else: 6748 self.error('the variable centering must be specified\n valid options are: P,A,B,C,I,F,R') 6749 #end if 6750 axes_prim = array([a1,a2,a3]) 6751 if cell=='primitive': 6752 axes = axes_prim 6753 elif cell=='conventional': 6754 axes = axes_conv 6755 points.extend(self.centerings[centering]) 6756 #end if 6757 elif known_crystal: 6758 axes = dot(diag([a,b,c]),array(axes)) 6759 #end if 6760 points = array(points,dtype=float) 6761 6762 elem = [] 6763 pos = [] 6764 if basis_vectors is None: 6765 basis_vectors = axes 6766 elif basis_vectors is 'primitive': 6767 basis_vectors = axes_prim 6768 elif basis_vectors is 'conventional': 6769 basis_vectors = axes_conv 6770 #end if 6771 nbasis = len(atoms) 6772 for point in points: 6773 for i in range(nbasis): 6774 atom = atoms[i] 6775 bpoint = basis[i] 6776 p = dot(point,axes) + dot(bpoint,basis_vectors) 6777 elem.append(atom) 6778 pos.append(p) 6779 #end for 6780 #end for 6781 pos = array(pos) 6782 6783 self.set( 6784 constants = array([a,b,c]), 6785 angles = array([alpha,beta,gamma]), 6786 generation_info = generation_info 6787 ) 6788 6789 Structure.__init__( 6790 self, 6791 axes = axes, 6792 scale = a, 6793 elem = elem, 6794 pos = pos, 6795 center = axes.sum(0)/2, 6796 units = units, 6797 frozen = frozen, 6798 mag = mag, 6799 magnetization = magnetization, 6800 tiling = tiling, 6801 kpoints = kpoints, 6802 kgrid = kgrid, 6803 kshift = kshift, 6804 permute = permute, 6805 rescale = False, 6806 operations = operations, 6807 use_prim = use_prim, 6808 add_kpath = add_kpath, 6809 symm_kgrid = symm_kgrid, 6810 ) 6811 6812 #end def __init__ 6813#end class Crystal 6814 6815 6816# test needed 6817class Jellium(Structure): 6818 prefactors = obj() 6819 prefactors.transfer_from({1:2*pi,2:4*pi,3:4./3*pi}) 6820 6821 def __init__(self,charge=None,background_charge=None,cell=None,volume=None,density=None,rs=None,dim=3, 6822 axes=None,kpoints=None,kweights=None,kgrid=None,kshift=None,units=None,tiling=None): 6823 del tiling 6824 if rs!=None: 6825 if not dim in self.prefactors: 6826 self.error('only 1,2, or 3 dimensional jellium is currently supported\n you requested one with dimension {0}'.format(dim)) 6827 #end if 6828 density = 1.0/(self.prefactors[dim]*rs**dim) 6829 #end if 6830 if axes!=None: 6831 cell = axes 6832 #end if 6833 if background_charge!=None: 6834 charge = background_charge 6835 #end if 6836 if cell!=None: 6837 cell = array(cell) 6838 dim = len(cell) 6839 volume = det(cell) 6840 elif volume!=None: 6841 volume = float(volume) 6842 cell = volume**(1./dim)*identity(dim) 6843 #end if 6844 if density!=None: 6845 density = float(density) 6846 if charge is None and volume!=None: 6847 charge = density*volume 6848 elif volume is None and charge!=None: 6849 volume = charge/density 6850 cell = volume**(1./dim)*identity(dim) 6851 #end if 6852 #end if 6853 if charge is None or cell is None: 6854 self.error('not enough information to form jellium structure\n information provided:\n charge: {0}\n cell: {1}\n volume: {2}\n density: {3}\n rs: {4}\n dim: {5}'.format(charge,cell,volume,density,rs,dim)) 6855 #end if 6856 Structure.__init__(self,background_charge=charge,axes=cell,dim=dim,kpoints=kpoints,kweights=kweights,kgrid=kgrid,kshift=kshift,units=units) 6857 #end def __init__ 6858 6859 def density(self): 6860 return self.background_charge/self.volume() 6861 #end def density 6862 6863 def rs(self): 6864 return 1.0/(self.density()*self.prefactors[self.dim])**(1./self.dim) 6865 #end def rs 6866 6867 def tile(self): 6868 self.not_implemented() 6869 #end def tile 6870#end class Jellium 6871 6872 6873 6874 6875 6876 6877 6878 6879# test needed 6880def generate_cell(shape,tiling=None,scale=1.,units=None,struct_type=Structure): 6881 if tiling is None: 6882 tiling = (1,1,1) 6883 #end if 6884 axes = Sobj() 6885 axes.sc = 1.*array([[ 1,0,0],[0, 1,0],[0,0, 1]]) 6886 axes.bcc = .5*array([[-1,1,1],[1,-1,1],[1,1,-1]]) 6887 axes.fcc = .5*array([[ 1,1,0],[1, 0,1],[0,1, 1]]) 6888 ax = dot(diag(tiling),axes[shape]) 6889 center = ax.sum(0)/2 6890 c = Structure(axes=ax,scale=scale,center=center,units=units) 6891 if struct_type!=Structure: 6892 c=c.upcast(struct_type) 6893 #end if 6894 return c 6895#end def generate_cell 6896 6897 6898 6899def generate_structure(type='crystal',*args,**kwargs): 6900 if type=='crystal': 6901 s = generate_crystal_structure(*args,**kwargs) 6902 elif type=='defect': 6903 s = generate_defect_structure(*args,**kwargs) 6904 elif type=='atom': 6905 s = generate_atom_structure(*args,**kwargs) 6906 elif type=='dimer': 6907 s = generate_dimer_structure(*args,**kwargs) 6908 elif type=='trimer': 6909 s = generate_trimer_structure(*args,**kwargs) 6910 elif type=='jellium': 6911 s = generate_jellium_structure(*args,**kwargs) 6912 elif type=='empty': 6913 s = Structure() 6914 elif type=='basic': 6915 s = Structure(*args,**kwargs) 6916 else: 6917 Structure.class_error(str(type)+' is not a valid structure type\noptions are crystal, defect, atom, dimer, trimer, jellium, empty, or basic') 6918 #end if 6919 return s 6920#end def generate_structure 6921 6922 6923 6924 6925# test needed 6926def generate_atom_structure( 6927 atom = None, 6928 units = 'A', 6929 Lbox = None, 6930 skew = 0, 6931 axes = None, 6932 kgrid = (1,1,1), 6933 kshift = (0,0,0), 6934 bconds = tuple('nnn'), 6935 struct_type = Structure 6936 ): 6937 if atom is None: 6938 Structure.class_error('atom must be provided','generate_atom_structure') 6939 #end if 6940 if Lbox!=None: 6941 axes = [[Lbox*(1-skew),0,0],[0,Lbox,0],[0,0,Lbox*(1+skew)]] 6942 #end if 6943 if axes is None: 6944 s = Structure(elem=[atom],pos=[[0,0,0]],units=units,bconds=bconds) 6945 else: 6946 s = Structure(elem=[atom],pos=[[0,0,0]],axes=axes,kgrid=kgrid,kshift=kshift,bconds=bconds,units=units) 6947 s.center_molecule() 6948 #end if 6949 6950 return s 6951#end def generate_atom_structure 6952 6953 6954# test needed 6955def generate_dimer_structure( 6956 dimer = None, 6957 units = 'A', 6958 separation = None, 6959 Lbox = None, 6960 skew = 0, 6961 axes = None, 6962 kgrid = (1,1,1), 6963 kshift = (0,0,0), 6964 bconds = tuple('nnn'), 6965 struct_type = Structure, 6966 axis = 'x' 6967 ): 6968 if dimer is None: 6969 Structure.class_error('dimer atoms must be provided to construct dimer','generate_dimer_structure') 6970 #end if 6971 if separation is None: 6972 Structure.class_error('separation must be provided to construct dimer','generate_dimer_structure') 6973 #end if 6974 if Lbox!=None: 6975 axes = [[Lbox*(1-skew),0,0],[0,Lbox,0],[0,0,Lbox*(1+skew)]] 6976 #end if 6977 if axis=='x': 6978 p2 = [separation,0,0] 6979 elif axis=='y': 6980 p2 = [0,separation,0] 6981 elif axis=='z': 6982 p2 = [0,0,separation] 6983 else: 6984 Structure.class_error('dimer orientation axis must be x,y,z\n you provided: {0}'.format(axis),'generate_dimer_structure') 6985 #end if 6986 if axes is None: 6987 s = Structure(elem=dimer,pos=[[0,0,0],p2],units=units,bconds=bconds) 6988 else: 6989 s = Structure(elem=dimer,pos=[[0,0,0],p2],axes=axes,kgrid=kgrid,kshift=kshift,units=units,bconds=bconds) 6990 s.center_molecule() 6991 #end if 6992 return s 6993#end def generate_dimer_structure 6994 6995 6996# test needed 6997def generate_trimer_structure( 6998 trimer = None, 6999 units = 'A', 7000 separation = None, 7001 angle = None, 7002 Lbox = None, 7003 skew = 0, 7004 axes = None, 7005 kgrid = (1,1,1), 7006 kshift = (0,0,0), 7007 struct_type = Structure, 7008 axis = 'x', 7009 axis2 = 'y', 7010 angular_units = 'degrees', 7011 plane_rot = None 7012 ): 7013 if trimer is None: 7014 Structure.class_error('trimer atoms must be provided to construct trimer','generate_trimer_structure') 7015 #end if 7016 if separation is None: 7017 Structure.class_error('separation must be provided to construct trimer','generate_trimer_structure') 7018 #end if 7019 if len(separation)!=2: 7020 Structure.class_error('two separation distances (atom1-atom2,atom1-atom3) must be provided to construct trimer\nyou provided {0} separation distances'.format(len(separation)),'generate_trimer_structure') 7021 #end if 7022 if angle is None: 7023 Structure.class_error('angle must be provided to construct trimer','generate_trimer_structure') 7024 #end if 7025 if angular_units=='degrees': 7026 angle *= pi/180 7027 elif not angular_units.startswith('rad'): 7028 Structure.class_error('angular units must be degrees or radians\nyou provided: {0}'.format(angular_units),'generate_trimer_structure') 7029 #end if 7030 if axis==axis2: 7031 Structure.class_error('axis and axis2 must be different to define the trimer plane\nyou provided {0} for both'.format(axis),'generate_trimer_structure') 7032 #end if 7033 if Lbox!=None: 7034 axes = [[Lbox*(1-skew),0,0],[0,Lbox,0],[0,0,Lbox*(1+skew)]] 7035 #end if 7036 p1 = [0,0,0] 7037 if axis=='x': 7038 p2 = [separation[0],0,0] 7039 elif axis=='y': 7040 p2 = [0,separation[0],0] 7041 elif axis=='z': 7042 p2 = [0,0,separation[0]] 7043 else: 7044 Structure.class_error('trimer bond1 (atom2-atom1) orientation axis must be x,y,z\n you provided: {0}'.format(axis),'generate_trimer_structure') 7045 #end if 7046 r = separation[1] 7047 c = cos(angle) 7048 s = sin(angle) 7049 axpair = axis+axis2 7050 if axpair=='xy': 7051 p3 = [r*c,r*s,0] 7052 elif axpair=='yx': 7053 p3 = [r*s,r*c,0] 7054 elif axpair=='yz': 7055 p3 = [0,r*c,r*s] 7056 elif axpair=='zy': 7057 p3 = [0,r*s,r*c] 7058 elif axpair=='zx': 7059 p3 = [r*s,0,r*c] 7060 elif axpair=='xz': 7061 p3 = [r*c,0,r*s] 7062 else: 7063 Structure.class_error('trimer bond2 (atom3-atom1) orientation axis must be x,y,z\n you provided: {0}'.format(axis2),'generate_trimer_structure') 7064 #end if 7065 if axes is None: 7066 s = Structure(elem=trimer,pos=[p1,p2,p3],units=units) 7067 else: 7068 s = Structure(elem=trimer,pos=[p1,p2,p3],axes=axes,kgrid=kgrid,kshift=kshift,units=units) 7069 s.center_molecule() 7070 #end if 7071 if plane_rot!=None: 7072 s.rotate_plane(axpair,plane_rot,angular_units) 7073 #end if 7074 return s 7075#end def generate_trimer_structure 7076 7077 7078# test needed 7079def generate_jellium_structure(*args,**kwargs): 7080 return Jellium(*args,**kwargs) 7081#end def generate_jellium_structure 7082 7083 7084 7085 7086def generate_crystal_structure( 7087 lattice = None, 7088 cell = None, 7089 centering = None, 7090 constants = None, 7091 atoms = None, 7092 basis = None, 7093 basis_vectors = None, 7094 tiling = None, 7095 cscale = None, 7096 axes = None, 7097 units = None, 7098 angular_units = 'degrees', 7099 magnetization = None, 7100 mag = None, 7101 kpoints = None, 7102 kweights = None, 7103 kgrid = None, 7104 kshift = (0,0,0), 7105 permute = None, 7106 operations = None, 7107 struct_type = Crystal, 7108 elem = None, 7109 pos = None, 7110 frozen = None, 7111 posu = None, 7112 elem_pos = None, 7113 folded_elem = None, 7114 folded_pos = None, 7115 folded_units = None, 7116 use_prim = None, 7117 add_kpath = False, 7118 symm_kgrid = False, 7119 #legacy inputs 7120 structure = None, 7121 shape = None, 7122 element = None, 7123 scale = None, 7124 ): 7125 7126 if structure is not None: 7127 lattice = structure 7128 #end if 7129 if shape is not None: 7130 cell = shape 7131 #end if 7132 if element is not None: 7133 atoms = element 7134 #end if 7135 if scale is not None: 7136 constants = scale 7137 #end if 7138 7139 #interface for total manual specification 7140 # this is only here because 'crystal' is default and must handle other cases 7141 s = None 7142 has_elem_and_pos = elem is not None and (pos is not None or posu is not None) 7143 has_elem_and_pos |= elem_pos is not None 7144 if has_elem_and_pos: 7145 s = Structure( 7146 axes = axes, 7147 elem = elem, 7148 pos = pos, 7149 units = units, 7150 mag = mag, 7151 frozen = frozen, 7152 magnetization = magnetization, 7153 tiling = tiling, 7154 kpoints = kpoints, 7155 kgrid = kgrid, 7156 kshift = kshift, 7157 permute = permute, 7158 rescale = False, 7159 operations = operations, 7160 posu = posu, 7161 elem_pos = elem_pos, 7162 use_prim = use_prim, 7163 add_kpath = add_kpath, 7164 symm_kgrid = symm_kgrid, 7165 ) 7166 elif isinstance(structure,Structure): 7167 s = structure 7168 if use_prim is not None and use_prim is not False: 7169 s.become_primitive(source=use_prim,add_kpath=add_kpath) 7170 #end if 7171 if tiling is not None: 7172 s = s.tile(tiling) 7173 #end if 7174 if kpoints is not None: 7175 s.add_kpoints(kpoints,kweights) 7176 #end if 7177 if kgrid is not None: 7178 if not symm_kgrid: 7179 s.add_kmesh(kgrid,kshift) 7180 else: 7181 s.add_symmetrized_kmesh(kgrid,kshift) 7182 #end if 7183 #end if 7184 #end if 7185 if s is not None: 7186 # add point group folded molecular system if present 7187 if folded_elem is not None and folded_pos is not None: 7188 if folded_units is None: 7189 folded_units = units 7190 #end if 7191 fs = Structure( 7192 elem = folded_elem, 7193 pos = folded_pos, 7194 units = folded_units, 7195 rescale = False, 7196 ) 7197 s.set_folded(fs) 7198 #end if 7199 return s 7200 #end if 7201 7202 s=Crystal( 7203 lattice = lattice , 7204 cell = cell , 7205 centering = centering , 7206 constants = constants , 7207 atoms = atoms , 7208 basis = basis , 7209 basis_vectors = basis_vectors , 7210 tiling = tiling , 7211 cscale = cscale , 7212 axes = axes , 7213 units = units , 7214 angular_units = angular_units , 7215 frozen = frozen , 7216 mag = mag , 7217 magnetization = magnetization , 7218 kpoints = kpoints , 7219 kgrid = kgrid , 7220 kshift = kshift , 7221 permute = permute , 7222 operations = operations , 7223 elem = elem , 7224 pos = pos , 7225 use_prim = use_prim , 7226 add_kpath = add_kpath , 7227 symm_kgrid = symm_kgrid , 7228 ) 7229 7230 if struct_type!=Crystal: 7231 s=s.upcast(struct_type) 7232 #end if 7233 7234 return s 7235#end def generate_crystal_structure 7236 7237 7238 7239defects = obj( 7240 diamond = obj( 7241 H = obj( 7242 pristine = [[0,0,0]], 7243 defect = [[0,0,0],[.625,.375,.375]] 7244 ), 7245 T = obj( 7246 pristine = [[0,0,0]], 7247 defect = [[0,0,0],[.5,.5,.5]] 7248 ), 7249 X = obj( 7250 pristine = [[.25,.25,.25]], 7251 defect = [[.39,.11,.15],[.11,.39,.15]] 7252 ), 7253 FFC = obj( 7254 #pristine = [[ 0, 0, 0],[.25,.25,.25]], 7255 #defect = [[.151,.151,-.08],[.10,.10,.33]] 7256 pristine = [[ 0, 0, 0],[.25 ,.25 ,.25 ],[.5 ,.5 , 0],[.75 ,.75 ,.25 ],[1.5 ,1.5 ,0 ],[1.75 ,1.75 ,.25 ]], 7257 defect = [[.151,.151,-.081],[.099,.099,.331],[.473,.473,-.059],[.722,.722,.230],[1.528,1.528,.020],[1.777,1.777,.309]] 7258 ) 7259 ) 7260 ) 7261 7262 7263def generate_defect_structure(defect,structure,shape=None,element=None, 7264 tiling=None,scale=1.,kgrid=None,kshift=(0,0,0), 7265 units=None,struct_type=DefectStructure): 7266 if structure in defects: 7267 dstruct = defects[structure] 7268 else: 7269 DefectStructure.class_error('defects for '+structure+' structure have not yet been implemented') 7270 #end if 7271 if defect in dstruct: 7272 drep = dstruct[defect] 7273 else: 7274 DefectStructure.class_error(defect+' defect not found for '+structure+' structure') 7275 #end if 7276 7277 ds = generate_crystal_structure( 7278 structure = structure, 7279 shape = shape, 7280 element = element, 7281 tiling = tiling, 7282 scale = 1.0, 7283 kgrid = kgrid, 7284 kshift = kshift, 7285 units = units, 7286 struct_type = struct_type 7287 ) 7288 7289 ds.replace(drep.pristine,pos=drep.defect) 7290 7291 ds.rescale(scale) 7292 7293 return ds 7294#end def generate_defect_structure 7295 7296 7297def read_structure(filepath,elem=None,format=None): 7298 s = generate_structure('empty') 7299 s.read(filepath,elem=elem,format=format) 7300 return s 7301#end def read_structure 7302 7303 7304 7305 7306 7307 7308if __name__=='__main__': 7309 7310 large = generate_structure( 7311 type = 'crystal', 7312 structure = 'diamond', 7313 cell = 'fcc', 7314 atoms = 'Ge', 7315 constants = 5.639, 7316 units = 'A', 7317 tiling = (2,2,2), 7318 kgrid = (1,1,1), 7319 kshift = (0,0,0), 7320 ) 7321 7322 small = large.folded_structure 7323 7324 print(small.kpoints_unit()) 7325 7326 prim = read_structure('scf.struct.xsf') 7327 prim = get_primitive_cell(structure=prim)['structure'] 7328 tiling = get_band_tiling(structure=prim, kpoints_label = ['L', 'F'], min_volfac=6, max_volfac = 6) 7329 7330 print(tiling) 7331#end if 7332