1# Copyright (C) 2003  CAMP
2# Copyright (C) 2014 R. Warmbier Materials for Energy Research Group,
3# Wits University
4# Please see the accompanying LICENSE file for further information.
5from typing import Tuple
6
7from ase.io import read
8from ase.utils import gcd
9import numpy as np
10
11import _gpaw
12import gpaw.mpi as mpi
13
14
def frac(f: float,
         n: int = 2 * 3 * 4 * 5,
         tol: float = 1e-6) -> Tuple[int, int]:
    """Convert to fraction.

    Parameters:

    f: float
        Number to convert.
    n: int
        Common denominator to try; the returned denominator always
        divides n, so any fraction whose denominator divides n (for the
        default 120: halves, thirds, quarters, fifths, sixths, eighths,
        ...) can be recovered.
    tol: float
        Largest accepted deviation of f from the nearest multiple of 1/n.

    Returns (numerator, denominator) in lowest terms; for positive n the
    denominator is positive and the sign is carried by the numerator.

    Raises ValueError if f is not within tol of a multiple of 1/n.

    >>> frac(0.5)
    (1, 2)
    """
    if f == 0:
        return 0, 1
    x = n * f
    # abs(n * f - round(n * f)) > n * tol  <=>  abs(f - round(n * f) / n) > tol
    if abs(x - round(x)) > n * tol:
        raise ValueError('{0} is not a multiple of 1/{1} within tolerance {2}'
                         .format(f, n, tol))
    x = int(round(x))
    d = gcd(x, n)
    return x // d, n // d
31
32
def sfrac(f: float) -> str:
    """Format *f* as the string ``'numerator/denominator'``.

    Zero is rendered as ``'0'`` rather than ``'0/1'``.  Values that are not
    simple fractions propagate the ValueError raised by ``frac``.

    >>> sfrac(0.5)
    '1/2'
    >>> sfrac(2 / 3)
    '2/3'
    >>> sfrac(0)
    '0'
    """
    if f != 0:
        numerator, denominator = frac(f)
        return '%d/%d' % (numerator, denominator)
    return '0'
46
47
class Symmetry:
    """Interface class for determination of symmetry, point and space groups.

    It also provides methods to apply symmetry operations to k-point grids,
    wave functions and forces.
    """
    def __init__(self, id_a, cell_cv, pbc_c=np.ones(3, bool), tolerance=1e-7,
                 point_group=True, time_reversal=True, symmorphic=True,
                 allow_invert_aperiodic_axes=True):
        """Construct symmetry object.

        Parameters:

        id_a: list of int
            Numbered atomic types.  Atoms with equal ids are treated as
            equivalent when searching for symmetries.
        cell_cv: array(3,3), float
            Cartesian lattice vectors
        pbc_c: array(3), bool
            Periodic boundary conditions.  (The array default is safe: it
            is copied by np.array() below and never mutated.)
        tolerance: float
            Tolerance for symmetry determination.  It is used both for the
            cell metric and for scaled atomic positions.
        point_group: bool
            Use point-group symmetries.
        time_reversal: bool
            Use time-reversal symmetry.
        symmorphic: bool
            Switch for the use of non-symmorphic symmetries aka: symmetries
            with fractional translations.  Default is to use only symmorphic
            symmetries.
        allow_invert_aperiodic_axes: bool
            Allow operations that invert non-periodic axes.  Kept for
            compatibility with old gpw-files.

        Attributes:

        op_scc:
            Array of rotation matrices (acting on scaled positions as
            ``spos_c @ op_cc``)
        ft_sc:
            Array of fractional translation vectors
        a_sa:
            Array of atomic indices after symmetry operation
        has_inversion:
            (bool) Have inversion
        gcd_c:
            Per-axis least common multiple of the denominators of all
            accepted fractional translations; a grid along axis c can only
            represent them if its size is a multiple of gcd_c[c].
        """

        self.id_a = id_a
        self.cell_cv = np.array(cell_cv, float)
        assert self.cell_cv.shape == (3, 3)
        self.pbc_c = np.array(pbc_c, bool)
        self.tol = tolerance
        self.symmorphic = symmorphic
        self.point_group = point_group
        self.time_reversal = time_reversal

        # Until analyze() is called only the identity is known:
        self.op_scc = np.identity(3, int).reshape((1, 3, 3))
        self.ft_sc = np.zeros((1, 3))
        self.a_sa = np.arange(len(id_a)).reshape((1, -1))
        self.has_inversion = False
        self.gcd_c = np.ones(3, int)

        # For reading old gpw-files:
        self.allow_invert_aperiodic_axes = allow_invert_aperiodic_axes

    def analyze(self, spos_ac):
        """Determine list of symmetry operations.

        First determine all symmetry operations of the cell. Then call
        ``prune_symmetries_atoms`` to remove those symmetries that are not
        satisfied by the atoms.

        It is not mandatory to call this method.  If not called, only
        time reversal symmetry may be used.
        """
        if self.point_group:
            self.find_lattice_symmetry()
            self.prune_symmetries_atoms(spos_ac)

    def find_lattice_symmetry(self):
        """Determine list of symmetry operations of the bare cell.

        Sets op_scc to all integer matrices (entries -1, 0, 1) that
        conserve the cell metric and respect the boundary conditions, and
        resets ft_sc to zero translations.
        """
        # Symmetry operations as matrices in 123 basis.
        # Operation is a 3x3 matrix, with possible elements -1, 0, 1, thus
        # there are 3**9 = 19683 possible matrices:
        combinations = 1 - np.indices([3] * 9)
        U_scc = combinations.reshape((3, 3, 3**9)).transpose((2, 0, 1))

        # The metric of the cell should be conserved after applying
        # the operation:
        metric_cc = self.cell_cv.dot(self.cell_cv.T)
        metric_scc = np.einsum('sij, jk, slk -> sil',
                               U_scc, metric_cc, U_scc,
                               optimize=True)
        mask_s = abs(metric_scc - metric_cc).sum(2).sum(1) <= self.tol
        U_scc = U_scc[mask_s]

        # Operation must not swap axes that don't have same PBC:
        pbc_cc = np.logical_xor.outer(self.pbc_c, self.pbc_c)
        mask_s = ~U_scc[:, pbc_cc].any(axis=1)
        U_scc = U_scc[mask_s]

        if not self.allow_invert_aperiodic_axes:
            # Operation must not invert axes that are not periodic:
            mask_s = (U_scc[:, np.diag(~self.pbc_c)] == 1).all(axis=1)
            U_scc = U_scc[mask_s]

        self.op_scc = U_scc
        self.ft_sc = np.zeros((len(self.op_scc), 3))

    def prune_symmetries_atoms(self, spos_ac):
        """Remove symmetries that are not satisfied by the atoms.

        spos_ac: ndarray
            Scaled atomic positions, one row per atom.

        For an empty spos_ac only a_sa is reset.  Otherwise op_scc, ft_sc,
        a_sa and has_inversion are replaced (operations without fractional
        translation first) and gcd_c is updated for every accepted
        fractional translation.
        """

        if len(spos_ac) == 0:
            self.a_sa = np.zeros((len(self.op_scc), 0), int)
            return

        # Build lists of atom numbers for each type of atom - one
        # list for each combination of atomic number, setup type,
        # magnetic moment and basis set (as encoded in id_a):
        a_ij = {}
        for a, id_ in enumerate(self.id_a):
            a_ij.setdefault(id_, []).append(a)

        a_j = a_ij[self.id_a[0]]  # just pick the first species

        # If a pure translation maps the atoms onto themselves, the cell is
        # a supercell: disable fractional translations.
        if not self.symmorphic:
            op_cc = np.identity(3, int)
            ftrans_sc = spos_ac[a_j[1:]] - spos_ac[a_j[0]]
            ftrans_sc -= np.rint(ftrans_sc)
            for ft_c in ftrans_sc:
                a_a = self.check_one_symmetry(spos_ac, op_cc, ft_c, a_ij)
                if a_a is not None:
                    self.symmorphic = True
                    break

        symmetries = []
        ftsymmetries = []

        # go through all possible symmetry operations
        for op_cc in self.op_scc:
            # first ignore fractional translations
            a_a = self.check_one_symmetry(spos_ac, op_cc, [0, 0, 0], a_ij)
            if a_a is not None:
                symmetries.append((op_cc, [0, 0, 0], a_a))
            elif not self.symmorphic:
                # check fractional translations
                sposrot_ac = np.dot(spos_ac, op_cc)
                ftrans_jc = sposrot_ac[a_j] - spos_ac[a_j[0]]
                ftrans_jc -= np.rint(ftrans_jc)
                for ft_c in ftrans_jc:
                    try:
                        nom_c, denom_c = np.array([frac(ft, tol=self.tol)
                                                   for ft in ft_c]).T
                    except ValueError:
                        # Not a simple rational translation; skip it.
                        continue
                    ft_c = nom_c / denom_c
                    a_a = self.check_one_symmetry(spos_ac, op_cc, ft_c, a_ij)
                    if a_a is not None:
                        ftsymmetries.append((op_cc, ft_c, a_a))
                        # A grid along axis c represents this translation
                        # only if its size is a multiple of d, so keep the
                        # least common multiple of all denominators seen.
                        for c, d in enumerate(denom_c):
                            self.gcd_c[c] = np.lcm(self.gcd_c[c], d)

        # Add symmetry operations with fractional translations at the end:
        symmetries.extend(ftsymmetries)
        self.op_scc = np.array([sym[0] for sym in symmetries])
        self.ft_sc = np.array([sym[1] for sym in symmetries])
        self.a_sa = np.array([sym[2] for sym in symmetries])

        inv_cc = -np.eye(3, dtype=int)
        self.has_inversion = (self.op_scc == inv_cc).all(2).all(1).any()

    def check_one_symmetry(self, spos_ac, op_cc, ft_c, a_ij):
        """Checks whether atoms satisfy one given symmetry operation.

        Returns an integer array a_a such that atom a is mapped onto atom
        a_a[a] by ``spos @ op_cc - ft_c`` (modulo lattice vectors), or None
        if some atom has no image of its own type.
        """

        a_a = np.zeros(len(spos_ac), int)
        for a_j in a_ij.values():
            spos_jc = spos_ac[a_j]
            for a in a_j:
                spos_c = np.dot(spos_ac[a], op_cc)
                sdiff_jc = spos_c - spos_jc - ft_c
                sdiff_jc -= sdiff_jc.round()
                indices = np.where(abs(sdiff_jc).max(1) < self.tol)[0]
                if len(indices) == 1:
                    j = indices[0]
                    a_a[a] = a_j[j]
                else:
                    # More than one match means overlapping atoms.
                    assert len(indices) == 0
                    return

        return a_a

    def check(self, spos_ac):
        """Check if positions satisfy symmetry operations.

        Raises RuntimeError if pruning with spos_ac removes any operation.
        """

        nsymold = len(self.op_scc)
        self.prune_symmetries_atoms(spos_ac)
        if len(self.op_scc) < nsymold:
            raise RuntimeError('Broken symmetry!')

    def reduce(self, bzk_kc, comm=None):
        """Reduce k-points to irreducible part of the BZ.

        bzk_kc: ndarray
            Scaled k-points of the full BZ.
        comm:
            Communicator passed on to ``map_k_points_fast``.

        Returns the tuple (ibzk_kc, weight_k, sym_k, time_reversal_k,
        bz2ibz_k, ibz2bz_k, bz2bz_ks): the IBZ k-points, their weights,
        for each BZ k-point the index of the operation mapping its IBZ
        representative onto it and whether time reversal is needed on top,
        the BZ->IBZ and IBZ->BZ index maps, and the full k-point map.
        """
        nbzkpts = len(bzk_kc)
        U_scc = self.op_scc
        nsym = len(U_scc)

        # With inversion present, time reversal adds nothing new:
        time_reversal = self.time_reversal and not self.has_inversion
        bz2bz_ks = map_k_points_fast(bzk_kc, U_scc, time_reversal,
                                     comm, self.tol)

        # The extra last element absorbs the -1 ("no match") entries of
        # bz2bz_ks and is dropped afterwards:
        bz2bz_k = -np.ones(nbzkpts + 1, int)
        ibz2bz_k = []
        for k in range(nbzkpts - 1, -1, -1):
            # Reverse order looks more natural
            if bz2bz_k[k] == -1:
                bz2bz_k[bz2bz_ks[k]] = k
                ibz2bz_k.append(k)
        ibz2bz_k = np.array(ibz2bz_k[::-1])
        bz2bz_k = bz2bz_k[:-1].copy()

        bz2ibz_k = np.empty(nbzkpts, int)
        bz2ibz_k[ibz2bz_k] = np.arange(len(ibz2bz_k))
        bz2ibz_k = bz2ibz_k[bz2bz_k]

        weight_k = np.bincount(bz2ibz_k) * (1.0 / nbzkpts)

        # Symmetry operation mapping IBZ to BZ:
        sym_k = np.empty(nbzkpts, int)
        for k in range(nbzkpts):
            # We pick the first one found:
            s_s = np.where(bz2bz_ks[bz2bz_k[k]] == k)[0]
            if len(s_s) == 0:
                raise IndexError(
                    'No symmetry operation maps IBZ representative '
                    '(BZ k-point {0}) onto BZ k-point {1} (nbzkpts={2})'
                    .format(bz2bz_k[k], k, nbzkpts))
            sym_k[k] = s_s[0]

        # Time-reversal symmetry used on top of the point group operation:
        if time_reversal:
            time_reversal_k = sym_k >= nsym
            sym_k %= nsym
        else:
            time_reversal_k = np.zeros(nbzkpts, bool)

        assert (ibz2bz_k[bz2ibz_k] == bz2bz_k).all()
        for k in range(nbzkpts):
            sign = 1 - 2 * time_reversal_k[k]
            dq_c = (np.dot(U_scc[sym_k[k]], bzk_kc[bz2bz_k[k]]) -
                    sign * bzk_kc[k])
            dq_c -= dq_c.round()
            assert abs(dq_c).max() < 1e-10

        return (bzk_kc[ibz2bz_k], weight_k,
                sym_k, time_reversal_k, bz2ibz_k, ibz2bz_k, bz2bz_ks)

    def check_grid(self, N_c) -> bool:
        """Check that symmetries are commensurate with grid.

        N_c: number of grid points along each axis.  Returns True when every
        rotation maps grid points onto grid points and every fractional
        translation is a whole number of grid spacings.
        """
        for U_cc, ft_c in zip(self.op_scc, self.ft_sc):
            t_c = ft_c * N_c
            # Make sure all grid-points map onto another grid-point:
            if (((N_c * U_cc).T % N_c).any() or
                not np.allclose(t_c, t_c.round())):
                return False
        return True

    def symmetrize(self, a, gd):
        """Symmetrize array ``a`` in place via ``gd.symmetrize``."""
        gd.symmetrize(a, self.op_scc, self.ft_sc)

    def symmetrize_positions(self, spos_ac):
        """Symmetrizes the atomic positions.

        Returns the average of the scaled positions over all operations,
        each image wrapped into [0, 1) (with a 1e-5 tolerance).
        """
        spos_tmp_ac = np.zeros_like(spos_ac)
        spos_new_ac = np.zeros_like(spos_ac)
        for i, op_cc in enumerate(self.op_scc):
            spos_tmp_ac[:] = 0.
            for a in range(len(spos_ac)):
                spos_c = np.dot(spos_ac[a], op_cc) - self.ft_sc[i]
                # Bring back the negative ones:
                spos_c = spos_c - np.floor(spos_c + 1e-5)
                spos_tmp_ac[self.a_sa[i][a]] += spos_c
            spos_new_ac += spos_tmp_ac

        spos_new_ac /= len(self.op_scc)
        return spos_new_ac

    def symmetrize_wavefunction(self, a_g, kibz_c, kbz_c, op_cc,
                                time_reversal):
        """Generate Bloch function from symmetry related function in the IBZ.

        a_g: ndarray
            Array with Bloch function from the irreducible BZ.
        kibz_c: ndarray
            Corresponding k-point coordinates.
        kbz_c: ndarray
            K-point coordinates of the symmetry related k-point.
        op_cc: ndarray
            Point group operation connecting the two k-points.
        time_reversal: bool
            Time-reversal symmetry required in addition to the point group
            symmetry to connect the two k-points.
        """

        # Identity
        if (np.abs(op_cc - np.eye(3, dtype=int)) < 1e-10).all():
            if time_reversal:
                return a_g.conj()
            else:
                return a_g
        # Inversion symmetry
        elif (np.abs(op_cc + np.eye(3, dtype=int)) < 1e-10).all():
            return a_g.conj()
        # General point group symmetry
        else:
            b_g = np.zeros_like(a_g)
            if time_reversal:
                # assert abs(np.dot(op_cc, kibz_c) - -kbz_c) < tol
                _gpaw.symmetrize_wavefunction(a_g, b_g, op_cc.T.copy(),
                                              kibz_c, -kbz_c)
                return b_g.conj()
            else:
                # assert abs(np.dot(op_cc, kibz_c) - kbz_c) < tol
                _gpaw.symmetrize_wavefunction(a_g, b_g, op_cc.T.copy(),
                                              kibz_c, kbz_c)
                return b_g

    def symmetrize_forces(self, F0_av):
        """Symmetrize forces.

        Returns the average over all operations of the Cartesian forces
        F0_av rotated and re-assigned according to a_sa.
        """
        F_ac = np.zeros_like(F0_av)
        for map_a, op_cc in zip(self.a_sa, self.op_scc):
            # Rotation expressed in Cartesian coordinates:
            op_vv = np.dot(np.linalg.inv(self.cell_cv),
                           np.dot(op_cc, self.cell_cv))
            for a1, a2 in enumerate(map_a):
                F_ac[a2] += np.dot(F0_av[a1], op_vv)
        return F_ac / len(self.op_scc)

    def __str__(self):
        n = len(self.op_scc)
        nft = self.ft_sc.any(1).sum()
        lines = ['Symmetries present (total): {0}'.format(n)]
        if not self.symmorphic:
            lines.append(
                'Symmetries with fractional translations: {0}'.format(nft))

        # X-Y grid of symmetry matrices:

        lines.append('')
        nx = 6 if self.symmorphic else 3
        for y in range((n + nx - 1) // nx):
            for c in range(3):
                line = ''
                for x in range(nx):
                    s = x + y * nx
                    if s == n:
                        break
                    op_c = self.op_scc[s, c]
                    ft = self.ft_sc[s, c]
                    line += '  (%2d %2d %2d)' % tuple(op_c)
                    if not self.symmorphic:
                        line += ' + (%4s)' % sfrac(ft)
                lines.append(line)
            lines.append('')
        return '\n'.join(lines)
423
424
def map_k_points(bzk_kc, U_scc, time_reversal, comm=None, tol=1e-11):
    """Find symmetry relations between k-points.

    Python wrapper around the C routine ``_gpaw.map_k_points``; the
    k-points are split into contiguous slices, one per rank of comm, and
    the partial tables are summed at the end.

    The map bz2bz_ks is returned.  If there is a k2 for which::

      = _    _    _
      U q  = q  + N,
       s k1   k2

    where N is a vector of integers, then bz2bz_ks[k1, s] = k2, otherwise
    if there is a k2 for which::

      = _     _    _
      U q  = -q  + N,
       s k1    k2

    then bz2bz_ks[k1, s + nsym] = k2, where nsym = len(U_scc).  Entries
    for which no partner exists are -1.
    """

    if comm is None or isinstance(comm, mpi.DryRunCommunicator):
        comm = mpi.serial_comm

    if time_reversal:
        # Second half of the operations: the same rotations times -1.
        U_scc = np.concatenate([U_scc, -U_scc])

    nbzkpts = len(bzk_kc)
    rank, size = comm.rank, comm.size
    ka = nbzkpts * rank // size
    kb = nbzkpts * (rank + 1) // size
    assert comm.sum(kb - ka) == nbzkpts

    # Rows owned by other ranks stay 0 so that summing over ranks
    # assembles the full table; own rows start out as -1 (presumably left
    # untouched by the C routine where no partner is found).
    bz2bz_ks = np.zeros((nbzkpts, len(U_scc)), int)
    bz2bz_ks[ka:kb] = -1
    _gpaw.map_k_points(np.ascontiguousarray(bzk_kc),
                       np.ascontiguousarray(U_scc), tol, bz2bz_ks, ka, kb)
    comm.sum(bz2bz_ks)
    return bz2bz_ks
465
466
def map_k_points_fast(bzk_kc, U_scc, time_reversal, comm=None, tol=1e-7):
    """Find symmetry relations between k-points.

    Same result as map_k_points(), but much faster: instead of a brute
    force search, the original k-points and their images are sorted
    lexicographically so that equal points end up next to each other.

    bzk_kc: ndarray
        kpoint coordinates.
    U_scc: ndarray
        Symmetry operations
    time_reversal: Bool
        Use time reversal symmetry in mapping.
    comm:
        Communicator (accepted for interface compatibility; not used here).
    tol: float
        When kpoint are closer than tol, they are
        considered to be identical.
    """

    nbzkpts = len(bzk_kc)

    if time_reversal:
        U_scc = np.concatenate([U_scc, -U_scc])

    bz2bz_ks = np.full((nbzkpts, len(U_scc)), -1, int)

    # Number of decimals kept when rounding coordinates:
    decimals = -np.log10(tol).astype(int)

    for s, U_cc in enumerate(U_scc):
        # Rows [0, nbzkpts) are the original points, rows
        # [nbzkpts, 2 * nbzkpts) their images under U_cc:
        k_kc = np.concatenate([bzk_kc, np.dot(bzk_kc, U_cc.T)])

        # Wrap into the unit cell, snap nearby values together and round:
        k_kc = np.mod(np.mod(k_kc, 1), 1)
        aglomerate_points(k_kc, tol)
        k_kc = np.mod(k_kc.round(decimals), 1)

        # After lexicographic sorting, identical points are neighbours:
        order_k = np.lexsort(k_kc.T)
        same_p = (np.diff(k_kc[order_k], axis=0) == 0).all(1)
        first_p = order_k[:-1][same_p]
        second_p = order_k[1:][same_p]

        # Every equal pair must be (original point, image point):
        assert (first_p < nbzkpts).all()
        assert (second_p >= nbzkpts).all()
        bz2bz_ks[second_p - nbzkpts, s] = first_p

    return bz2bz_ks
523
524
def aglomerate_points(k_kc, tol):
    """Snap nearby coordinate values together, in place.

    Each column of k_kc is treated independently: its values are sorted
    and split into groups wherever two consecutive sorted values differ
    by more than tol (so chains of close values form one group).  Every
    value is then replaced by the smallest value of its group.

    k_kc: ndarray
        2-d float array, modified in place.  An empty array is left
        unchanged.
    tol: float
        Largest gap between consecutive sorted values within a group.
    """
    if len(k_kc) == 0:
        return
    for c in range(k_kc.shape[1]):
        order_k = np.argsort(k_kc[:, c])
        sorted_k = k_kc[order_k, c]
        # Group label of each sorted value; a new group starts after a
        # gap larger than tol:
        group_k = np.concatenate(([0], np.cumsum(np.diff(sorted_k) > tol)))
        # Position (in sorted order) of the first member of each group:
        first_k = np.searchsorted(group_k, group_k)
        k_kc[order_k, c] = sorted_k[first_k]
539
540
def atoms2symmetry(atoms, id_a=None, tolerance=1e-7):
    """Create symmetry object from atoms object.

    atoms:
        ASE Atoms object (cell, pbc and scaled positions are read).
    id_a:
        Atom type ids; defaults to the atomic numbers.
    tolerance: float
        Passed on to Symmetry.

    The returned Symmetry allows fractional translations, does not use
    time reversal, and has already been analyzed.
    """
    types_a = atoms.get_atomic_numbers() if id_a is None else id_a
    symmetry = Symmetry(types_a, atoms.cell, atoms.pbc,
                        tolerance=tolerance,
                        time_reversal=False,
                        symmorphic=False)
    symmetry.analyze(atoms.get_scaled_positions())
    return symmetry
551
552
class CLICommand:
    """Analyse symmetry."""

    @staticmethod
    def add_arguments(parser):
        # Single positional argument: structure file passed to ase.io.read.
        parser.add_argument('filename')

    @staticmethod
    def run(args):
        # Read the structure, analyze it and print the symmetry report.
        print(atoms2symmetry(read(args.filename)))
565