# Copyright (C) 2003 CAMP
# Copyright (C) 2014 R. Warmbier Materials for Energy Research Group,
# Wits University
# Please see the accompanying LICENSE file for further information.
from typing import Tuple

from ase.io import read
from ase.utils import gcd
import numpy as np

import _gpaw
import gpaw.mpi as mpi


def frac(f: float,
         n: int = 2 * 3 * 4 * 5,
         tol: float = 1e-6) -> Tuple[int, int]:
    """Convert to fraction.

    Finds the fraction m/n (reduced to lowest terms) closest to ``f``.
    Raises ``ValueError`` if ``f`` is not within ``tol`` of such a fraction.

    >>> frac(0.5)
    (1, 2)
    """
    if f == 0:
        return 0, 1
    x = n * f
    if abs(x - round(x)) > n * tol:
        raise ValueError(
            '{} is not a fraction with denominator dividing {}'.format(f, n))
    x = int(round(x))
    d = gcd(x, n)
    return x // d, n // d


def sfrac(f: float) -> str:
    """Format as fraction.

    >>> sfrac(0.5)
    '1/2'
    >>> sfrac(2 / 3)
    '2/3'
    >>> sfrac(0)
    '0'
    """
    if f == 0:
        return '0'
    return '%d/%d' % frac(f)


class Symmetry:
    """Interface class for determination of symmetry, point and space groups.

    It can also apply symmetry operations to k-point grids, wavefunctions
    and forces.
    """
    def __init__(self, id_a, cell_cv, pbc_c=(True, True, True),
                 tolerance=1e-7,
                 point_group=True, time_reversal=True, symmorphic=True,
                 allow_invert_aperiodic_axes=True):
        """Construct symmetry object.

        Parameters:

        id_a: list of int
            Numbered atomic types
        cell_cv: array(3,3), float
            Cartesian lattice vectors
        pbc_c: array(3), bool
            Periodic boundary conditions.
        tolerance: float
            Tolerance for symmetry determination.  It is used as an
            absolute threshold, both on the deviation of the cell metric
            and on scaled atomic positions.
        symmorphic: bool
            Switch for the use of non-symmorphic symmetries aka: symmetries
            with fractional translations.  Default is to use only symmorphic
            symmetries.
        point_group: bool
            Use point-group symmetries.
        time_reversal: bool
            Use time-reversal symmetry.
        allow_invert_aperiodic_axes: bool
            Allow operations that invert non-periodic axes (set to False
            for reading old gpw-files).

        Attributes:

        op_scc:
            Array of rotation matrices
        ft_sc:
            Array of fractional translation vectors
        a_sa:
            Array of atomic indices after symmetry operation
        has_inversion:
            (bool) Have inversion
        """

        self.id_a = id_a
        self.cell_cv = np.array(cell_cv, float)
        assert self.cell_cv.shape == (3, 3)
        # np.array() makes a private copy, so the caller's array is
        # never shared with this object.
        self.pbc_c = np.array(pbc_c, bool)
        self.tol = tolerance
        self.symmorphic = symmorphic
        self.point_group = point_group
        self.time_reversal = time_reversal

        # Start with only the identity operation:
        self.op_scc = np.identity(3, int).reshape((1, 3, 3))
        self.ft_sc = np.zeros((1, 3))
        self.a_sa = np.arange(len(id_a)).reshape((1, -1))
        self.has_inversion = False
        # Per-axis common multiple of the denominators of the fractional
        # translations found (see prune_symmetries_atoms):
        self.gcd_c = np.ones(3, int)

        # For reading old gpw-files:
        self.allow_invert_aperiodic_axes = allow_invert_aperiodic_axes

    def analyze(self, spos_ac):
        """Determine list of symmetry operations.

        First determine all symmetry operations of the cell.  Then call
        ``prune_symmetries_atoms`` to remove those symmetries that are not
        satisfied by the atoms.

        It is not mandatory to call this method.  If not called, only
        time reversal symmetry may be used.
        """
        if self.point_group:
            self.find_lattice_symmetry()
        self.prune_symmetries_atoms(spos_ac)

    def find_lattice_symmetry(self):
        """Determine list of symmetry operations of the lattice.

        Sets ``op_scc`` to all integer matrices with entries in {-1, 0, 1}
        that conserve the cell metric (and respect the periodic boundary
        conditions), and resets ``ft_sc`` to zero translations.
        """
        # Symmetry operations as matrices in 123 basis.
        # Operation is a 3x3 matrix, with possible elements -1, 0, 1, thus
        # there are 3**9 = 19683 possible matrices:
        combinations = 1 - np.indices([3] * 9)
        U_scc = combinations.reshape((3, 3, 3**9)).transpose((2, 0, 1))

        # The metric of the cell should be conserved after applying
        # the operation:
        metric_cc = self.cell_cv.dot(self.cell_cv.T)
        metric_scc = np.einsum('sij, jk, slk -> sil',
                               U_scc, metric_cc, U_scc,
                               optimize=True)
        mask_s = abs(metric_scc - metric_cc).sum(2).sum(1) <= self.tol
        U_scc = U_scc[mask_s]

        # Operation must not swap axes that don't have same PBC:
        pbc_cc = np.logical_xor.outer(self.pbc_c, self.pbc_c)
        mask_s = ~U_scc[:, pbc_cc].any(axis=1)
        U_scc = U_scc[mask_s]

        if not self.allow_invert_aperiodic_axes:
            # Operation must not invert axes that are not periodic:
            mask_s = (U_scc[:, np.diag(~self.pbc_c)] == 1).all(axis=1)
            U_scc = U_scc[mask_s]

        self.op_scc = U_scc
        self.ft_sc = np.zeros((len(self.op_scc), 3))

    def prune_symmetries_atoms(self, spos_ac):
        """Remove symmetries that are not satisfied by the atoms.

        Updates ``op_scc``, ``ft_sc``, ``a_sa`` and ``has_inversion``.
        When fractional translations are enabled (``symmorphic`` is False),
        ``gcd_c`` is updated, and ``symmorphic`` is switched on if a pure
        translation (supercell) is detected.
        """

        if len(spos_ac) == 0:
            self.a_sa = np.zeros((len(self.op_scc), 0), int)
            return

        # Build lists of atom numbers for each type of atom - one
        # list for each combination of atomic number, setup type,
        # magnetic moment and basis set:
        a_ij = {}
        for a, id_ in enumerate(self.id_a):
            a_ij.setdefault(id_, []).append(a)

        a_j = a_ij[self.id_a[0]]  # just pick the first species

        # if supercell disable fractional translations:
        if not self.symmorphic:
            op_cc = np.identity(3, int)
            ftrans_sc = spos_ac[a_j[1:]] - spos_ac[a_j[0]]
            ftrans_sc -= np.rint(ftrans_sc)
            for ft_c in ftrans_sc:
                a_a = self.check_one_symmetry(spos_ac, op_cc, ft_c, a_ij)
                if a_a is not None:
                    self.symmorphic = True
                    break

        symmetries = []
        ftsymmetries = []

        # go through all possible symmetry operations
        for op_cc in self.op_scc:
            # first ignore fractional translations
            a_a = self.check_one_symmetry(spos_ac, op_cc, [0, 0, 0], a_ij)
            if a_a is not None:
                symmetries.append((op_cc, [0, 0, 0], a_a))
            elif not self.symmorphic:
                # check fractional translations
                sposrot_ac = np.dot(spos_ac, op_cc)
                ftrans_jc = sposrot_ac[a_j] - spos_ac[a_j[0]]
                ftrans_jc -= np.rint(ftrans_jc)
                for ft_c in ftrans_jc:
                    try:
                        nom_c, denom_c = np.array([frac(ft, tol=self.tol)
                                                   for ft in ft_c]).T
                    except ValueError:
                        # Not a simple fraction: cannot be a valid
                        # fractional translation.
                        continue
                    ft_c = nom_c / denom_c
                    a_a = self.check_one_symmetry(spos_ac, op_cc, ft_c, a_ij)
                    if a_a is not None:
                        ftsymmetries.append((op_cc, ft_c, a_a))
                        # Keep gcd_c a multiple of every denominator found
                        # (a common multiple, not necessarily the least):
                        for c, d in enumerate(denom_c):
                            if self.gcd_c[c] % d != 0:
                                self.gcd_c[c] *= d

        # Add symmetry operations with fractional translations at the end:
        symmetries.extend(ftsymmetries)
        self.op_scc = np.array([sym[0] for sym in symmetries])
        self.ft_sc = np.array([sym[1] for sym in symmetries])
        self.a_sa = np.array([sym[2] for sym in symmetries])

        inv_cc = -np.eye(3, dtype=int)
        self.has_inversion = (self.op_scc == inv_cc).all(2).all(1).any()

    def check_one_symmetry(self, spos_ac, op_cc, ft_c, a_ij):
        """Checks whether atoms satisfy one given symmetry operation.

        Returns an integer array a_a where a_a[a] is the index of the atom
        that atom ``a`` is mapped onto, or None if the operation
        (rotation ``op_cc`` plus translation ``ft_c``) is not a symmetry.
        """

        a_a = np.zeros(len(spos_ac), int)
        for a_j in a_ij.values():
            spos_jc = spos_ac[a_j]
            for a in a_j:
                spos_c = np.dot(spos_ac[a], op_cc)
                sdiff_jc = spos_c - spos_jc - ft_c
                sdiff_jc -= sdiff_jc.round()
                indices = np.where(abs(sdiff_jc).max(1) < self.tol)[0]
                if len(indices) == 1:
                    j = indices[0]
                    a_a[a] = a_j[j]
                else:
                    # Two atoms of the same type closer than tol would
                    # give several matches:
                    assert len(indices) == 0
                    return

        return a_a

    def check(self, spos_ac):
        """Check if positions satisfy symmetry operations.

        Raises RuntimeError if any of the current operations is broken.
        """

        nsymold = len(self.op_scc)
        self.prune_symmetries_atoms(spos_ac)
        if len(self.op_scc) < nsymold:
            raise RuntimeError('Broken symmetry!')

    def reduce(self, bzk_kc, comm=None):
        """Reduce k-points to irreducible part of the BZ.

        Parameters:

        bzk_kc: ndarray
            Scaled k-point coordinates of the full BZ.
        comm:
            Passed on to map_k_points_fast() (which does not use it).

        Returns the tuple (ibzk_kc, weight_k, sym_k, time_reversal_k,
        bz2ibz_k, ibz2bz_k, bz2bz_ks): the irreducible k-points, their
        weights, the symmetry operation and time-reversal flag relating
        each BZ point to its IBZ representative, the BZ->IBZ and IBZ->BZ
        index maps, and the full (k, s) mapping table.
        """
        nbzkpts = len(bzk_kc)
        U_scc = self.op_scc
        nsym = len(U_scc)

        # With inversion symmetry, time reversal gives nothing new:
        time_reversal = self.time_reversal and not self.has_inversion
        bz2bz_ks = map_k_points_fast(bzk_kc, U_scc, time_reversal,
                                     comm, self.tol)

        # One extra slot: entries of bz2bz_ks equal to -1 (no mapping)
        # write into the last element, which is dropped afterwards.
        bz2bz_k = -np.ones(nbzkpts + 1, int)
        ibz2bz_k = []
        for k in range(nbzkpts - 1, -1, -1):
            # Reverse order looks more natural
            if bz2bz_k[k] == -1:
                bz2bz_k[bz2bz_ks[k]] = k
                ibz2bz_k.append(k)
        ibz2bz_k = np.array(ibz2bz_k[::-1])
        bz2bz_k = bz2bz_k[:-1].copy()

        bz2ibz_k = np.empty(nbzkpts, int)
        bz2ibz_k[ibz2bz_k] = np.arange(len(ibz2bz_k))
        bz2ibz_k = bz2ibz_k[bz2bz_k]

        weight_k = np.bincount(bz2ibz_k) * (1.0 / nbzkpts)

        # Symmetry operation mapping IBZ to BZ:
        sym_k = np.empty(nbzkpts, int)
        for k in range(nbzkpts):
            # We pick the first one found:
            s_s = np.where(bz2bz_ks[bz2bz_k[k]] == k)[0]
            if len(s_s) == 0:
                raise IndexError(
                    'No symmetry operation maps IBZ k-point {} onto BZ '
                    'k-point {} (of {}); mapping row: {}'.format(
                        bz2bz_k[k], k, nbzkpts, bz2bz_ks[bz2bz_k[k]]))
            sym_k[k] = s_s[0]

        # Time-reversal symmetry used on top of the point group operation:
        if time_reversal:
            time_reversal_k = sym_k >= nsym
            sym_k %= nsym
        else:
            time_reversal_k = np.zeros(nbzkpts, bool)

        assert (ibz2bz_k[bz2ibz_k] == bz2bz_k).all()
        for k in range(nbzkpts):
            sign = 1 - 2 * time_reversal_k[k]
            dq_c = (np.dot(U_scc[sym_k[k]], bzk_kc[bz2bz_k[k]]) -
                    sign * bzk_kc[k])
            dq_c -= dq_c.round()
            assert abs(dq_c).max() < 1e-10

        return (bzk_kc[ibz2bz_k], weight_k,
                sym_k, time_reversal_k, bz2ibz_k, ibz2bz_k, bz2bz_ks)

    def check_grid(self, N_c) -> bool:
        """Check that symmetries are commensurate with grid."""
        for U_cc, ft_c in zip(self.op_scc, self.ft_sc):
            t_c = ft_c * N_c
            # Make sure all grid-points map onto another grid-point:
            if (((N_c * U_cc).T % N_c).any() or
                    not np.allclose(t_c, t_c.round())):
                return False
        return True

    def symmetrize(self, a, gd):
        """Symmetrize array."""
        gd.symmetrize(a, self.op_scc, self.ft_sc)

    def symmetrize_positions(self, spos_ac):
        """Symmetrizes the atomic positions.

        Returns the average over all symmetry operations of the
        transformed scaled positions; the input is not modified.
        """
        spos_tmp_ac = np.zeros_like(spos_ac)
        spos_new_ac = np.zeros_like(spos_ac)
        for i, op_cc in enumerate(self.op_scc):
            spos_tmp_ac[:] = 0.
            for a in range(len(spos_ac)):
                spos_c = np.dot(spos_ac[a], op_cc) - self.ft_sc[i]
                # Bring back the negative ones:
                spos_c = spos_c - np.floor(spos_c + 1e-5)
                spos_tmp_ac[self.a_sa[i][a]] += spos_c
            spos_new_ac += spos_tmp_ac

        spos_new_ac /= len(self.op_scc)
        return spos_new_ac

    def symmetrize_wavefunction(self, a_g, kibz_c, kbz_c, op_cc,
                                time_reversal):
        """Generate Bloch function from symmetry related function in the IBZ.

        a_g: ndarray
            Array with Bloch function from the irreducible BZ.
        kibz_c: ndarray
            Corresponding k-point coordinates.
        kbz_c: ndarray
            K-point coordinates of the symmetry related k-point.
        op_cc: ndarray
            Point group operation connecting the two k-points.
        time-reversal: bool
            Time-reversal symmetry required in addition to the point group
            symmetry to connect the two k-points.
        """
        eye_cc = np.eye(3, dtype=int)

        # Identity
        if (np.abs(op_cc - eye_cc) < 1e-10).all():
            if time_reversal:
                return a_g.conj()
            else:
                return a_g
        # Inversion symmetry
        elif (np.abs(op_cc + eye_cc) < 1e-10).all():
            # Inversion alone maps k -> -k, which (as in the identity
            # branch with time reversal) is done by complex conjugation.
            # NOTE(review): this presumably relies on the system being
            # time-reversal invariant.
            # Inversion combined with time reversal maps k back onto
            # itself, so the function is returned unchanged.
            if time_reversal:
                return a_g
            return a_g.conj()
        # General point group symmetry
        else:
            b_g = np.zeros_like(a_g)
            if time_reversal:
                # assert abs(np.dot(op_cc, kibz_c) - -kbz_c) < tol
                _gpaw.symmetrize_wavefunction(a_g, b_g, op_cc.T.copy(),
                                              kibz_c, -kbz_c)
                return b_g.conj()
            else:
                # assert abs(np.dot(op_cc, kibz_c) - kbz_c) < tol
                _gpaw.symmetrize_wavefunction(a_g, b_g, op_cc.T.copy(),
                                              kibz_c, kbz_c)
                return b_g

    def symmetrize_forces(self, F0_av):
        """Symmetrize forces.

        Returns the forces averaged over all symmetry operations, with each
        operation converted to Cartesian form via the cell.
        """
        F_ac = np.zeros_like(F0_av)
        for map_a, op_cc in zip(self.a_sa, self.op_scc):
            op_vv = np.dot(np.linalg.inv(self.cell_cv),
                           np.dot(op_cc, self.cell_cv))
            for a1, a2 in enumerate(map_a):
                F_ac[a2] += np.dot(F0_av[a1], op_vv)
        return F_ac / len(self.op_scc)

    def __str__(self):
        n = len(self.op_scc)
        nft = self.ft_sc.any(1).sum()
        lines = ['Symmetries present (total): {0}'.format(n)]
        if not self.symmorphic:
            lines.append(
                'Symmetries with fractional translations: {0}'.format(nft))

        # X-Y grid of symmetry matrices:

        lines.append('')
        nx = 6 if self.symmorphic else 3
        ns = len(self.op_scc)
        for y in range((ns + nx - 1) // nx):
            for c in range(3):
                line = ''
                for x in range(nx):
                    s = x + y * nx
                    if s == ns:
                        break
                    op_c = self.op_scc[s, c]
                    ft = self.ft_sc[s, c]
                    line += ' (%2d %2d %2d)' % tuple(op_c)
                    if not self.symmorphic:
                        line += ' + (%4s)' % sfrac(ft)
                lines.append(line)
            lines.append('')
        return '\n'.join(lines)


def map_k_points(bzk_kc, U_scc, time_reversal, comm=None, tol=1e-11):
    """Find symmetry relations between k-points.

    This is a Python-wrapper for a C-function that does the hard work
    which is distributed over comm.

    The map bz2bz_ks is returned.  If there is a k2 for which::

        U_s q_k1 = q_k2 + N,

    where N is a vector of integers, then bz2bz_ks[k1, s] = k2, otherwise
    if there is a k2 for which::

        U_s q_k1 = -q_k2 + N,

    then bz2bz_ks[k1, s + nsym] = k2, where nsym = len(U_scc).  Otherwise
    bz2bz_ks[k1, s] = -1.
    """

    if comm is None or isinstance(comm, mpi.DryRunCommunicator):
        comm = mpi.serial_comm

    nbzkpts = len(bzk_kc)
    # Each rank handles the slice ka:kb of k-points:
    ka = nbzkpts * comm.rank // comm.size
    kb = nbzkpts * (comm.rank + 1) // comm.size
    assert comm.sum(kb - ka) == nbzkpts

    if time_reversal:
        U_scc = np.concatenate([U_scc, -U_scc])

    # Rows outside ka:kb stay zero locally so that the sum over ranks
    # below assembles the full table.
    bz2bz_ks = np.zeros((nbzkpts, len(U_scc)), int)
    bz2bz_ks[ka:kb] = -1
    _gpaw.map_k_points(np.ascontiguousarray(bzk_kc),
                       np.ascontiguousarray(U_scc), tol, bz2bz_ks, ka, kb)
    comm.sum(bz2bz_ks)
    return bz2bz_ks


def map_k_points_fast(bzk_kc, U_scc, time_reversal, comm=None, tol=1e-7):
    """Find symmetry relations between k-points.

    Performs the same task as map_k_points(), but much faster.
    This is achieved by finding the symmetry related kpoints using
    lexical sorting instead of brute force searching.

    bzk_kc: ndarray
        kpoint coordinates.
    U_scc: ndarray
        Symmetry operations
    time_reversal: Bool
        Use time reversal symmetry in mapping.
    comm:
        Communicator (accepted for compatibility with map_k_points();
        not used here).
    tol: float
        When kpoint are closer than tol, they are
        considered to be identical.
    """

    nbzkpts = len(bzk_kc)

    if time_reversal:
        U_scc = np.concatenate([U_scc, -U_scc])

    bz2bz_ks = np.full((nbzkpts, len(U_scc)), -1, int)

    # Number of decimals to round to (loop invariant):
    decimals = int(-np.log10(tol))

    for s, U_cc in enumerate(U_scc):
        # Find mapped kpoints
        Ubzk_kc = np.dot(bzk_kc, U_cc.T)

        # Do some work on the input: wrap into [0, 1), snap nearly equal
        # coordinates together and round so equal points compare equal.
        k_kc = np.concatenate([bzk_kc, Ubzk_kc])
        k_kc = np.mod(np.mod(k_kc, 1), 1)
        aglomerate_points(k_kc, tol)
        k_kc = k_kc.round(decimals)
        k_kc = np.mod(k_kc, 1)

        # Find the lexicographical order
        order = np.lexsort(k_kc.T)
        k_kc = k_kc[order]
        diff_kc = np.diff(k_kc, axis=0)
        equivalentpairs_k = (diff_kc == 0).all(1)

        # Mapping array.
        orders = np.array([order[:-1][equivalentpairs_k],
                           order[1:][equivalentpairs_k]])

        # This has to be true: in each equal pair the original k-point
        # (index < nbzkpts) sorts before the rotated one.
        assert (orders[0] < nbzkpts).all()
        assert (orders[1] >= nbzkpts).all()
        bz2bz_ks[orders[1] - nbzkpts, s] = orders[0]

    return bz2bz_ks


def aglomerate_points(k_kc, tol):
    """Snap nearly identical coordinates together, in place.

    For each column, the sorted values are split into groups wherever
    consecutive values differ by more than ``tol``; every value in a
    group is set to the smallest value of that group.
    """
    nk = len(k_kc)
    if nk == 0:
        return
    nd = k_kc.shape[1]
    inds_kc = np.argsort(k_kc, axis=0)
    for c in range(nd):
        sk_k = k_kc[inds_kc[:, c], c]
        dk_k = np.diff(sk_k)

        # Partition the kpoints into groups
        pt_K = np.argwhere(dk_k > tol)[:, 0]
        pt_K = np.append(np.append(0, pt_K + 1), nk)
        for i in range(len(pt_K) - 1):
            k_kc[inds_kc[pt_K[i]:pt_K[i + 1], c],
                 c] = k_kc[inds_kc[pt_K[i], c], c]


def atoms2symmetry(atoms, id_a=None, tolerance=1e-7):
    """Create symmetry object from atoms object.

    Atoms are distinguished by atomic number unless ``id_a`` is given.
    Fractional translations are allowed and time reversal is disabled.
    """
    if id_a is None:
        id_a = atoms.get_atomic_numbers()
    symmetry = Symmetry(id_a, atoms.cell, atoms.pbc,
                        symmorphic=False,
                        time_reversal=False,
                        tolerance=tolerance)
    symmetry.analyze(atoms.get_scaled_positions())
    return symmetry


class CLICommand:
    """Analyse symmetry."""

    @staticmethod
    def add_arguments(parser):
        parser.add_argument('filename')

    @staticmethod
    def run(args):
        atoms = read(args.filename)
        symmetry = atoms2symmetry(atoms)
        print(symmetry)