1# coding: utf-8
2"""Objects used to deal with symmetry operations in crystals."""
3import sys
4import abc
5import warnings
6import collections
7import numpy as np
8import spglib
9
10from monty.string import is_string
11from monty.itertools import iuptri
12from monty.functools import lazy_property
13from monty.termcolor import cprint
14from monty.collections import dict2namedtuple
15from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
16from pymatgen.util.serialization import SlotPickleMixin
17from abipy.core.kpoints import wrap_to_ws, issamek, has_timrev_from_kptopt
18
19
# Public names exported by ``from <this module> import *``.
__all__ = [
    "LatticeRotation",
    "AbinitSpaceGroup",
]
24
25
def wrap_in_ucell(x):
    """
    Map the reduced coordinate(s) x to the equivalent value in the interval [0, 1).

    The reduction is done with the ``%`` operator, so python scalars and numpy
    arrays are both accepted.
    """
    reduced = x % 1
    return reduced
31
32
def is_integer(x, atol=1e-08):
    """
    Return True if every element of x is an integer within the absolute tolerance atol.

    The comparison is delegated to numpy.allclose, so its default relative
    tolerance is applied as well.

    >>> is_integer([1., 2.])
    True
    >>> is_integer(1.01, atol=0.011)
    True
    >>> is_integer([1.01, 2])
    False
    """
    return np.allclose(np.around(x), x, atol=atol)
46
47
def mati3inv(mat3, trans=True):
    """
    Invert (and optionally transpose) a 3x3 matrix of INTEGER elements with determinant +-1.

    Args:
        mat3: (3, 3) matrix-like object with integer elements.
        trans: If True (default) the TRANSPOSE of the inverse is returned,
            otherwise the inverse itself.

    Returns:
        |numpy-array| of ints with the TRANSPOSE of the inverse of mat3 if trans==True.
        If trans==False, the inverse of mat3 is returned.

    Raises:
        ValueError: if mat3 is singular or if abs(det(mat3)) != 1, since in the latter
            case the inverse is not an integer matrix and integer division would
            silently produce a wrong result.

    .. note::

       Used for symmetry operations. Rotation matrices in reduced coordinates have
       determinant +-1, hence their inverses are also integer arrays.
    """
    mat3 = np.reshape(np.array(mat3, dtype=int), (3, 3))

    # mit[i, j] is the (i, j) cofactor of mat3, hence
    # inv(mat3) = mit.T / det and inv(mat3).T = mit / det.
    mit = np.empty((3, 3), dtype=int)
    mit[0,0] = mat3[1,1] * mat3[2,2] - mat3[2,1] * mat3[1,2]
    mit[1,0] = mat3[2,1] * mat3[0,2] - mat3[0,1] * mat3[2,2]
    mit[2,0] = mat3[0,1] * mat3[1,2] - mat3[1,1] * mat3[0,2]
    mit[0,1] = mat3[2,0] * mat3[1,2] - mat3[1,0] * mat3[2,2]
    mit[1,1] = mat3[0,0] * mat3[2,2] - mat3[2,0] * mat3[0,2]
    mit[2,1] = mat3[1,0] * mat3[0,2] - mat3[0,0] * mat3[1,2]
    mit[0,2] = mat3[1,0] * mat3[2,1] - mat3[2,0] * mat3[1,1]
    mit[1,2] = mat3[2,0] * mat3[0,1] - mat3[0,0] * mat3[2,1]
    mit[2,2] = mat3[0,0] * mat3[1,1] - mat3[1,0] * mat3[0,1]

    # Determinant via the Laplace expansion along the first column.
    dd = mat3[0,0] * mit[0,0] + mat3[1,0] * mit[1,0] + mat3[2,0] * mit[2,0]

    # Make sure matrix is not singular
    if dd == 0:
        raise ValueError("Attempting to invert integer array: %s\n ==> determinant is zero." % str(mat3))

    # With abs(det) != 1 the inverse has non-integer entries: floor division would
    # return garbage instead of the inverse.
    if abs(dd) != 1:
        raise ValueError("Attempting to invert integer array: %s\n ==> determinant is %s, "
                         "the inverse is not an integer matrix (abs(det) must be 1)." % (str(mat3), dd))

    # Exact since dd is +1 or -1.
    mit = mit // dd
    return mit if trans else mit.T.copy()
88
89
def _get_det(mat):
    """
    Compute the determinant of the 3x3 rotation matrix mat (numpy-style [i, j] indexing).

    Returns:
        The determinant, which must be +1 or -1.

    Raises:
        ValueError: if abs(det) != 1.
    """
    a, b, c = mat[0,0], mat[0,1], mat[0,2]
    d, e, f = mat[1,0], mat[1,1], mat[1,2]
    g, h, i = mat[2,0], mat[2,1], mat[2,2]

    # Cofactor expansion along the first row.
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)

    if abs(det) == 1:
        return det

    raise ValueError("Determinant must be +-1 while it is %s" % det)
105
106
def indsym_from_symrel(symrel, tnons, structure, tolsym=1e-8):
    r"""
    For each symmetry operation, find the number of the position to
    which each atom is sent in the unit cell by the INVERSE of the
    symmetry operation inv(symrel); i.e. this is the atom which, when acted
    upon by the given symmetry element isym, gets transformed into atom iatom.
    Indirect indexing array for atoms, see symatm.F90.

    $ R^{-1} (xred(:,iat) - \tau) = xred(:,iat_sym) + R_0 $

    * indsym[iat, isym, 3] gives iat_sym in the original unit cell.
    * indsym[iat, isym, 0:3] gives the lattice vector $R_0$.

    Args:
        symrel: int (nsym,3,3) array with real space symmetries expressed in reduced coordinates.
        tnons: float (nsym, 3) array with nonsymmorphic translations for each symmetry.
        structure: |Structure| object.
        tolsym: tolerance for the symmetries

    Returns:
        (natom, nsym, 4) integer array with the lattice vector R_0 (first three entries)
        and the index iat_sym (last entry) for each atom and symmetry.

    Raises:
        ValueError: if for some (isym, iatom) the nearest candidate atom differs from
            the transformed coordinates by more than tolsym.
    """
    natom = len(structure)
    nsym = len(symrel)
    xred = np.array([site.frac_coords for site in structure], dtype=float)
    tnons = np.asarray(tnons, dtype=float)
    typat = [site.specie.symbol for site in structure]

    # R^{-1} in real space for each symmetry.
    rm1_list = [mati3inv(symrel[isym], trans=False) for isym in range(nsym)]

    # Abinit uses 1e-10 but python seems to require a slightly larger value.
    tol_match = 1e-9

    err = 0.0
    indsym = np.zeros((natom, nsym, 4), dtype=int)

    # Implementation is similar to Abinit routine (including the order of the loops)
    for isym in range(nsym):
        for iatom in range(natom):
            tratm = np.matmul(rm1_list[isym], xred[iatom] - tnons[isym])

            # Start testmn out at large value. As in Abinit, this must be reset for each
            # (isym, iatom) pair, otherwise the nearest candidate of later atoms is lost.
            testmn = np.inf
            difmin = np.full(3, np.inf)

            # Loop through atoms, when types agree, check for agreement after primitive translation
            for jatm in range(natom):
                if typat[jatm] != typat[iatom]:
                    continue
                test_vec = tratm - xred[jatm]
                # Find nearest integer part of difference
                trans = np.rint(test_vec)
                # Check whether, after translation, they agree
                test_vec = test_vec - trans
                diff = np.abs(test_vec).sum()
                if diff < tol_match:
                    difmin = test_vec
                    indsym[iatom, isym, :3] = trans
                    indsym[iatom, isym, 3] = jatm
                    # Break out of loop when agreement is within tolerance
                    break

                # Keep track of smallest difference if greater than tol_match
                if diff < testmn:
                    testmn = diff
                    # Note that abs() is not taken here
                    difmin = test_vec
                    indsym[iatom, isym, :3] = trans
                    indsym[iatom, isym, 3] = jatm

            # Keep track of maximum difference between transformed coordinates and nearest "target"
            # coordinate. This is done for every atom, not only for the last one.
            difmax = np.abs(difmin).max()
            err = max(err, difmax)
            if difmax > tolsym:
                msg = "\n".join([
                    "",
                    "Trouble finding symmetrically equivalent atoms.",
                    f"Applying inverse of symm number {isym} to atom number {iatom} "
                    f"of type {typat[iatom]} gives tratom={tratm}",
                    "This is further away from every atom in crystal than the allowed tolerance.",
                    "The inverse symmetry matrix is:",
                    f"{rm1_list[isym]}",
                    f"and the nonsymmorphic transl. tnons = {tnons[isym]}",
                    f"The nearest coordinate differs by {difmin} "
                    f"for indsym(nearest atom) = {indsym[iatom, isym, 3]}",
                    "",
                    "This indicates that when symatm attempts to find atoms symmetrically",
                    "related to a given atom, the nearest candidate is further away than some tolerance.",
                    "Should check atomic coordinates and symmetry group input data.",
                    "",
                ])
                cprint(msg, color="red")

    if err > tolsym:
        raise ValueError("maximum err %s is larger than tolsym: %s" % (err, tolsym))

    return indsym
194
195
class Operation(metaclass=abc.ABCMeta):
    """
    Abstract base class that defines the methods that must be
    implemented by the concrete class representing some sort of operation.
    """
    @abc.abstractmethod
    def __eq__(self, other):
        """O1 == O2"""

    def __ne__(self, other):
        return not (self == other)

    @abc.abstractmethod
    def __mul__(self, other):
        """O1 * O2"""

    @abc.abstractmethod
    def __hash__(self):
        """Operation can be used as dictionary keys."""

    @abc.abstractmethod
    def inverse(self):
        """Returns the inverse of self."""

    def opconj(self, other):
        """Returns X^-1 S X where X is the other symmetry operation."""
        return other.inverse() * self * other

    # abc.abstractproperty is deprecated since Python 3.3: stack property + abstractmethod.
    @property
    @abc.abstractmethod
    def isE(self):
        """True if self is the identity operator"""

    # TODO: commute, commutator, anticommute and direct_product are not implemented yet.
238
239
class SymmOp(Operation, SlotPickleMixin):
    """
    Crystalline symmetry: space group operation {R, tau} supplemented by a
    time-reversal sign and an anti-ferromagnetic sign.
    """
    # Absolute tolerance used to decide whether fractional translations
    # differ by a lattice vector (see __eq__ and isE).
    _ATOL_TAU = 1e-8

    __slots__ = [
        "rot_r",
        "rotm1_r",
        "tau",
        "time_sign",
        "afm_sign",
        "rot_g",
        "_det",
        "_trace",
    ]

    # TODO: Add lattice?
    def __init__(self, rot_r, tau, time_sign, afm_sign, rot_g=None):
        """
        This object represents a space group symmetry i.e. a symmetry of the crystal.

        Args:
            rot_r: (3,3) integer matrix with the rotational part in real space in reduced coordinates (C order).
            tau: fractional translation in reduced coordinates.
            time_sign: -1 if time reversal can be used, +1 otherwise.
            afm_sign: anti-ferromagnetic part [+1, -1].
            rot_g: Optional (3,3) integer matrix with the rotational part in reciprocal space
                i.e. R^{-1t}. Computed from rot_r if None, otherwise checked against rot_r.
        """
        rot_r = np.asarray(rot_r)

        # Store R and R^{-1} in real space.
        self.rot_r, self.rotm1_r = rot_r, mati3inv(rot_r, trans=False)
        self.tau = np.asarray(tau)

        self.afm_sign, self.time_sign = afm_sign, time_sign
        assert afm_sign in [-1, 1] and time_sign in [-1, 1]

        # Compute symmetry matrix in reciprocal space: S = R^{-1t}
        if rot_g is None:
            self.rot_g = mati3inv(rot_r, trans=True)
        else:
            assert np.all(rot_g == mati3inv(rot_r, trans=True))
            self.rot_g = rot_g

    # operator protocol.
    def __eq__(self, other):
        """
        Two operations are equal if they have the same rotation, the same time-reversal
        and AFM signs, and fractional translations that differ by a lattice vector.
        Comparison with objects that are not SymmOp is delegated to Python (NotImplemented).
        """
        if not isinstance(other, SymmOp):
            return NotImplemented

        # Note the two fractional translations are equivalent if they differ by a lattice vector.
        return (np.all(self.rot_r == other.rot_r) and
                is_integer(self.tau - other.tau, atol=self._ATOL_TAU) and
                self.afm_sign == other.afm_sign and
                self.time_sign == other.time_sign)

    def __mul__(self, other):
        """
        Returns a new :class:`SymmOp` which is equivalent to apply the "other" :class:`SymmOp`
        followed by this one i.e:

        {R,t} {S,u} = {RS, Ru + t}
        """
        return self.__class__(rot_r=np.dot(self.rot_r, other.rot_r),
                              tau=self.tau + np.dot(self.rot_r, other.tau),
                              time_sign=self.time_sign * other.time_sign,
                              afm_sign=self.afm_sign * other.afm_sign)

    def __hash__(self):
        """
        :class:`SymmOp` can be used as keys in dictionaries.
        Note that the hash is computed from integer values that are equal
        for operations that compare equal (same rotation and time_sign).
        """
        return int(8 * self.trace + 4 * self.det + 2 * self.time_sign)

    def inverse(self):
        """Returns inverse of transformation i.e. {R^{-1}, -R^{-1} tau}."""
        return self.__class__(rot_r=self.rotm1_r,
                              tau=-np.dot(self.rotm1_r, self.tau),
                              time_sign=self.time_sign,
                              afm_sign=self.afm_sign)

    @lazy_property
    def isE(self):
        """True if identity operator."""
        return (np.all(self.rot_r == np.eye(3, dtype=int)) and
                is_integer(self.tau, atol=self._ATOL_TAU) and
                self.time_sign == 1 and
                self.afm_sign == 1)
    # end operator protocol.

    def __repr__(self):
        return str(self)

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """
        String with three lines: rot_r row + tau component, followed by the rot_g row.
        The last line also reports time_sign, afm_sign and det.
        """
        def vec2str(vec):
            return "%2d,%2d,%2d" % tuple(v for v in vec)

        s = ""
        for i in range(3):
            s += "[" + vec2str(self.rot_r[i]) + ", %.3f]  " % self.tau[i] + "[" + vec2str(self.rot_g[i]) + "] "
            if i == 2:
                s += ", time_sign = %+1d, afm_sign = %+1d, det = %+1d" % (self.time_sign, self.afm_sign, self.det)
            s += "\n"

        return s

    @lazy_property
    def is_symmorphic(self):
        """
        True if the fractional translation is non-zero.

        .. note::

            Despite the name, this is True for operations with a non-zero translation
            (i.e. non-symmorphic operations). The name is kept for backward compatibility.
        """
        return np.any(np.abs(self.tau) > 0.0)

    @lazy_property
    def det(self):
        """Determinant of the rotation matrix [-1, +1]."""
        return _get_det(self.rot_r)

    @lazy_property
    def trace(self):
        """Trace of the rotation matrix."""
        return self.rot_r.trace()

    @lazy_property
    def is_proper(self):
        """True if the rotational part has determinant == 1."""
        return self.det == +1

    @lazy_property
    def has_timerev(self):
        """True if symmetry contains the time-reversal operator."""
        return self.time_sign == -1

    @lazy_property
    def is_fm(self):
        """True if self is a ferromagnetic symmetry."""
        return self.afm_sign == +1

    @lazy_property
    def is_afm(self):
        """True if self is an anti-ferromagnetic symmetry."""
        return self.afm_sign == -1

    def rotate_k(self, frac_coords, wrap_tows=False):
        """
        Apply the symmetry operation to the k-point given in reduced coordinates.

        Sk is wrapped to the first Brillouin zone if wrap_tows is True.
        """
        sk = np.dot(self.rot_g, frac_coords) * self.time_sign

        return wrap_to_ws(sk) if wrap_tows else sk

    def preserve_k(self, frac_coords, ret_g0=True):
        """
        Check if the operation preserves the k-point modulo a reciprocal lattice vector.

        Args:
            frac_coords: Fractional coordinates of the k-point
            ret_g0: False if only the boolean result is wanted.

        Returns:
            bool, g0 = S(k) - k

            bool is True if self preserves k and g0 is an integer vector.
        """
        sk = self.rotate_k(frac_coords, wrap_tows=False)

        if ret_g0:
            return issamek(sk, frac_coords), np.array(np.round(sk - frac_coords), dtype=int)
        else:
            return issamek(sk, frac_coords)

    def rotate_r(self, frac_coords, in_ucell=False):
        """
        Apply the symmetry operation to a point in real space given in reduced coordinates.
        The result is wrapped to [0, 1) if in_ucell is True.

        .. NOTE::

            We use the convention: symmop(r) = R^{-1} (r - tau)
        """
        rotm1_rmt = np.dot(self.rotm1_r, frac_coords - self.tau)

        return wrap_in_ucell(rotm1_rmt) if in_ucell else rotm1_rmt
434
435
class OpSequence(collections.abc.Sequence):
    """
    Mixin class providing the basic methods that are common to containers of operations.

    Subclasses must store the operations in ``self._ops``.
    """
    def __len__(self):
        return len(self._ops)

    def __iter__(self):
        return iter(self._ops)

    def __getitem__(self, key):
        return self._ops[key]

    def __contains__(self, op):
        return op in self._ops

    def __eq__(self, other):
        """
        Equality test.

        .. warning::

            The order of the operations in self and in other is not relevant.
        """
        if other is None: return False
        if len(self) != len(other):
            return False

        # Map each operation in self to the index of the (first) equal operation in other.
        # The order is irrelevant, but two ops of self must not match the same op of other.
        other_ops = list(other)
        founds = []
        for op in self:
            try:
                founds.append(other_ops.index(op))
            except ValueError:
                return False

        if len(set(founds)) == len(founds):
            return True

        warnings.warn("self contains duplicated ops! Likely a bug!")
        return False

    def __ne__(self, other):
        return not (self == other)

    def __str__(self):
        lines = [str(op) for op in self]
        return "\n".join(lines)

    def show_ops(self, stream=sys.stdout):
        """Write the string representation of the operations to stream."""
        lines = [str(op) for op in self]
        stream.write("\n".join(lines))

    def count(self, op):
        """Returns the number of occurences of operation op in self."""
        return self._ops.count(op)

    def index(self, op):
        """
        Return the (first) index of operation op in self.

        Raises:
            ValueError if not found.
        """
        return self._ops.index(op)

    def find(self, op):
        """Return the (first) index of op in self. -1 if not found."""
        try:
            return self.index(op)
        except ValueError:
            return -1

    def is_group(self):
        """True if this set of operations represent a group."""
        check = 0

        # Identity must be present.
        if [op.isE for op in self].count(True) != 1:
            check += 1

        # The inverse must be in the set.
        if [op.inverse() in self for op in self].count(True) != len(self):
            check += 2

        # The product of two members must be in the set.
        op_prods = [op1 * op2 for op1 in self for op2 in self]

        d = self.asdict()
        for op12 in op_prods:
            if op12 not in d:
                print("op12 not in group\n %s" % str(op12))
                check += 1

        return check == 0

    def is_commutative(self):
        """True if all operations commute with each other."""
        for op1, op2 in iuptri(self, diago=False):
            if op1 * op2 != op2 * op1: return False
        return True

    def is_abelian_group(self):
        """True if commutative group."""
        return self.is_commutative() and self.is_group()

    def asdict(self):
        """
        Returns a dictionary where the keys are the symmetry operations and
        the values are the indices of the operations in the iterable.
        """
        return {op: idx for idx, op in enumerate(self)}

    @lazy_property
    def mult_table(self):
        """
        Given a set of nsym 3x3 operations which are supposed to form a group,
        this routine constructs the multiplication table of the group.
        mtable[i,j] gives the index of the product S_i * S_j, or -1 if the
        product is not in the set (i.e. the operations do not form a group).
        """
        mtable = np.empty((len(self), len(self)), dtype=int)

        d = self.asdict()
        for i, op1 in enumerate(self):
            for j, op2 in enumerate(self):
                # Save the index of op12 in self. None cannot be stored in an int array,
                # hence -1 (same convention as find) is used for missing products.
                mtable[i, j] = d.get(op1 * op2, -1)

        return mtable

    @property
    def num_classes(self):
        """Number of classes."""
        return len(self.class_indices)

    @lazy_property
    def class_indices(self):
        """
        A class is defined as the set of distinct elements obtained by
        considering for each element, S, of the group all its conjugate
        elements X^-1 S X where X ranges over all the elements of the group.

        Returns:
            Nested list l = [cls0_indices, cls1_indices, ...] where each sublist
            contains the indices of the class. len(l) equals the number of classes.
        """
        found, class_indices = len(self) * [False], [[] for _ in range(len(self))]

        num_classes = -1
        for ii, op1 in enumerate(self):
            if found[ii]: continue
            num_classes += 1

            for op2 in self:
                # Form conjugate and search it among the operations
                # that have not been found yet.
                op1_conj = op1.opconj(op2)

                for kk, op3 in enumerate(self):
                    if not found[kk] and op1_conj == op3:
                        found[kk] = True
                        class_indices[num_classes].append(kk)

        class_indices = class_indices[:num_classes + 1]
        assert sum(len(c) for c in class_indices) == len(self)
        return class_indices

    def groupby_class(self, with_inds=False):
        """
        Iterate over the operations grouped in symmetry classes.

        Args:
            with_inds: If True, [op0, op1, ...], [ind_op0, ind_op1, ...] is returned.
        """
        if with_inds:
            for indices in self.class_indices:
                yield [self[i] for i in indices], indices
        else:
            for indices in self.class_indices:
                yield [self[i] for i in indices]
629
630
class AbinitSpaceGroup(OpSequence):
    """
    Container storing the space group symmetries.
    """

    def __init__(self, spgid, symrel, tnons, symafm, has_timerev, inord="C"):
        """
        Args:
            spgid (int): space group number (from 1 to 232, 0 if cannot be specified).
            symrel: (nsym,3,3) array with the rotational part of the symmetries in real
                space (reduced coordinates are assumed, see also `inord` for the order.
            tnons: (nsym,3) array with fractional translation in reduced coordinates.
            symafm: (nsym) array with +1 for Ferromagnetic symmetry and -1 for AFM
            has_timerev: True if time-reversal symmetry is included.
            inord: storage order of mat in symrel[:]. If inord == "F", mat.T is stored
                as matrices are always stored in C-order. Use inord == "F" if you have
                read symrel from an external file produced by abinit.

        .. note::

            All the arrays are stored in C-order. Use as_fortran_arrays to extract data
            that can be passes to Fortran routines.
        """
        self.spgid = spgid
        assert 233 > self.spgid >= 0

        # Time reversal symmetry.
        self._has_timerev = has_timerev
        self._time_signs = [+1, -1] if self.has_timerev else [+1]

        # np.array makes a copy: the Fortran -> C transposition below must not
        # modify the array passed by the caller.
        self._symrel = np.array(symrel)
        self._tnons, self._symafm = np.asarray(tnons), np.asarray(symafm)

        if len(self.symrel) != len(self.tnons) or len(self.symrel) != len(self.symafm):
            raise ValueError("symrel, tnons and symafm must have equal shape[0]")

        inord = inord.upper()
        assert inord in ["F", "C"]
        if inord == "F":
            # Fortran to C.
            for isym in range(len(self._symrel)):
                self._symrel[isym] = self._symrel[isym].T

        # Rotations in reciprocal space: S = R^{-1t}.
        self._symrec = self._symrel.copy()
        for isym in range(len(self.symrel)):
            self._symrec[isym] = mati3inv(self.symrel[isym], trans=True)

        # All the spatial symmetries with time_sign = +1 first, then (if present) with time_sign = -1.
        self._ops = tuple(
            SymmOp(rot_r=self.symrel[isym],
                   tau=self.tnons[isym],
                   time_sign=time_sign,
                   afm_sign=self.symafm[isym],
                   rot_g=self.symrec[isym])
            for time_sign in self._time_signs
            for isym in range(len(self.symrel))
        )

    @classmethod
    def from_ncreader(cls, r, inord="F"):
        """
        Builds the object from a netcdf reader
        """
        kptopt = int(r.read_value("kptopt", default=1))
        symrel = r.read_value("reduced_symmetry_matrices")

        return cls(spgid=r.read_value("space_group"),
                   symrel=symrel,
                   tnons=r.read_value("reduced_symmetry_translations"),
                   symafm=r.read_value("symafm"),
                   has_timerev=has_timrev_from_kptopt(kptopt),
                   inord=inord)

    @classmethod
    def from_structure(cls, structure, has_timerev=True, symprec=1e-5, angle_tolerance=5):
        """
        Takes a |Structure| object. Uses spglib to perform various symmetry finding operations.

        Args:
            structure: |Structure| object.
            has_timerev: True is time-reversal symmetry is included.
            symprec: Tolerance for symmetry finding.
            angle_tolerance: Angle tolerance for symmetry finding.

        .. warning::

            AFM symmetries are not supported.
        """
        # Call spglib to get the list of symmetry operations.
        spga = SpacegroupAnalyzer(structure, symprec=symprec, angle_tolerance=angle_tolerance)
        data = spga.get_symmetry_dataset()
        symrel = data["rotations"]

        return cls(spgid=data["number"],
                   symrel=symrel,
                   tnons=data["translations"],
                   symafm=len(symrel) * [1],
                   has_timerev=has_timerev,
                   inord="C")

    def __repr__(self):
        return self.to_string()

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation. The operations are listed if verbose > 1."""
        lines = ["spgid: %d, num_spatial_symmetries: %d, has_timerev: %s, symmorphic: %s" % (
            self.spgid, self.num_spatial_symmetries, self.has_timerev, self.is_symmorphic)]
        app = lines.append

        if verbose > 1:
            for op in self.symmops(time_sign=+1):
                app(str(op))

        return "\n".join(lines)

    @lazy_property
    def is_symmorphic(self):
        """
        True if there's at least one operation with non-zero fractional translation.

        .. note::

            Despite the name, True means that the group contains non-symmorphic
            operations. The name is kept for backward compatibility.
        """
        return any(op.is_symmorphic for op in self)

    @property
    def has_timerev(self):
        """True if time-reversal symmetry is present."""
        return self._has_timerev

    @property
    def symrel(self):
        """
        [nsym, 3, 3] int array with symmetries in reduced coordinates of the direct lattice.
        """
        return self._symrel

    @property
    def tnons(self):
        """
        [nsym, 3] float array with fractional translations in reduced coordinates of the direct lattice.
        """
        return self._tnons

    @property
    def symrec(self):
        """
        [nsym, 3, 3] int array with symmetries in reduced coordinates of the reciprocal lattice.
        """
        return self._symrec

    @property
    def symafm(self):
        """[nsym] int array with +1 if FM or -1 if AFM symmetry."""
        return self._symafm

    @property
    def num_spatial_symmetries(self):
        """Number of spatial symmetries (time-reversal partners are not counted)."""
        fact = 2 if self.has_timerev else 1
        return len(self) // fact

    @property
    def afm_symmops(self):
        """Tuple with antiferromagnetic symmetries."""
        return self.symmops(time_sign=None, afm_sign=-1)

    @property
    def fm_symmops(self):
        """Tuple of ferromagnetic symmetries."""
        return self.symmops(time_sign=None, afm_sign=+1)

    def symmops(self, time_sign=None, afm_sign=None):
        """
        Args:
            time_sign: If specified, only symmetries with time-reversal sign time_sign are returned.
            afm_sign: If specified, only symmetries with anti-ferromagnetic part afm_sign are returned.

        returns:
            tuple of :class:`SymmOp` instances.
        """
        symmops = []
        for sym in self._ops:
            gotit = True

            if time_sign:
                assert time_sign in self._time_signs
                gotit = gotit and sym.time_sign == time_sign

            if afm_sign:
                assert afm_sign in [-1,+1]
                gotit = gotit and sym.afm_sign == afm_sign

            if gotit:
                symmops.append(sym)

        return tuple(symmops)

    def symeq(self, k1_frac_coords, k2_frac_coords, atol=None):
        """
        Test whether two k-points in fractional coordinates are symmetry equivalent
        i.e. if there's a symmetry operations TO (including time-reversal T, if present)
        such that::

            TO(k1) = k2 + G0

        Return: namedtuple with::

            isym: The index of the symmetry operation such that TS(k1) = k2 + G0
                Set to -1 if k1 and k2 are not related by symmetry.
            op: Symmetry operation.
            g0: numpy vector.
        """
        for isym, sym in enumerate(self):
            sk_coords = sym.rotate_k(k1_frac_coords, wrap_tows=False)
            if issamek(sk_coords, k2_frac_coords, atol=atol):
                g0 = sk_coords - k2_frac_coords
                return dict2namedtuple(isym=isym, op=sym, g0=g0)

        return dict2namedtuple(isym=-1, op=None, g0=None)

    def find_little_group(self, kpoint):
        """
        Find the little group of the kpoint.

        Args:
            kpoint: Accept vector with the reduced coordinates or :class:`Kpoint` object.

        Returns:
            :class:`LittleGroup` object.
        """
        if hasattr(kpoint, "frac_coords"):
            frac_coords = kpoint.frac_coords
        else:
            frac_coords = np.reshape(kpoint, (3))

        # List with the symmetry operations that preserve the kpoint and the associated G0.
        # The ops are collected directly: indices of fm_symmops do not match indices
        # of self when AFM operations are present.
        k_symmops, g0vecs = [], []

        # Exclude AFM operations.
        for symmop in self.fm_symmops:
            is_same, g0 = symmop.preserve_k(frac_coords)
            if is_same:
                k_symmops.append(symmop)
                g0vecs.append(g0)

        return LittleGroup(kpoint, k_symmops, g0vecs)

    def get_spglib_hall_number(self, symprec=1e-5):
        """
        Uses spglib.get_hall_number_from_symmetry to determine the hall number
        based on the symmetry operations. Useful when the space group number
        is not available, but the symmetries are (e.g. the DDB file)

        Args:
            symprec: distance tolerance in fractional coordinates (not the standard
                in cartesian coordinates). See spglib docs for more details.

        Returns:
            int: the hall number.
        """
        return spglib.get_hall_number_from_symmetry(self.symrel, self.tnons, symprec=symprec)
903
904
# FIXME To maintain backward compatibility.
# Alias kept so that code referring to the old name `SpaceGroup` keeps working;
# new code should use AbinitSpaceGroup directly.
SpaceGroup = AbinitSpaceGroup
907
908
class LittleGroup(OpSequence):
    """
    Little group of a k-point: the symmetry operations S that preserve the k-point,
    i.e. S k = k + G0, together with the corresponding G0 vectors.
    """

    def __init__(self, kpoint, symmops, g0vecs):
        """
        Args:
            kpoint: The k-point. Either a vector with reduced coordinates or an object
                with a ``frac_coords`` attribute (e.g. :class:`Kpoint`).
            symmops: Sequence with the symmetry operations that preserve the k-point i.e. Sk = k + G0
            g0vecs: The G0 vectors, one for each operation in symmops. Reshaped to (nsym, 3).
        """
        self.kpoint = kpoint
        self._ops = symmops
        self.g0vecs = np.reshape(g0vecs, (-1, 3))
        assert len(self.symmops) == len(self.g0vecs)

        # Find the point group of k so that we know how to access the Bilbao database.
        # Operations are taken in reciprocal space (rot_g). Only time-reversal operations
        # are filtered out here; AFM operations are expected to be excluded by the caller
        # (as done in find_little_group).
        krots = np.array([o.rot_g for o in symmops if not o.has_timerev])
        self.kgroup = LatticePointGroup(krots)

    @lazy_property
    def is_symmorphic(self):
        """True if there's at least one operation with non-zero fractional translation."""
        # NOTE(review): despite its name, this flag is used (docstring above and the warning
        # in to_string) to signal the presence of a non-zero fractional translation.
        # It relies on op.is_symmorphic following the same convention -- confirm.
        return any(op.is_symmorphic for op in self)

    @property
    def symmops(self):
        """The symmetry operations that preserve the k-point."""
        return self._ops

    @lazy_property
    def on_bz_border(self):
        """
        True if the k-point is on the border of the BZ i.e. if at least one of the
        coordinates returned by wrap_to_ws is +-0.5 within 1e-8.
        """
        # kpoint can be either a plain vector or an object with frac_coords
        # (find_little_group passes its argument through unchanged).
        frac_coords = np.array(getattr(self.kpoint, "frac_coords", self.kpoint))
        kreds = wrap_to_ws(frac_coords)
        diff = np.abs(np.abs(kreds) - 0.5)
        return np.any(diff < 1e-8)

    def iter_symmop_g0(self):
        """Iterate over (symmop, g0) pairs."""
        for symmop, g0 in zip(self.symmops, self.g0vecs):
            yield symmop, g0

    def __repr__(self):
        return "Kpoint Group: %s, Kpoint: %s" % (self.kgroup, self.kpoint)

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation of little group."""
        lines = ["Kpoint-group: %s, Kpoint: %s, Symmorphic: %s" % (self.kgroup, self.kpoint, self.is_symmorphic)]
        app = lines.append
        app(" ")

        # Add character_table from Bilbao database.
        bilbao_ptgrp = bilbao_ptgroup(self.kgroup.sch_symbol)
        app(bilbao_ptgrp.to_string(verbose=verbose))
        app("")

        # Write warning if non-symmorphic little group with k-point at zone border.
        # (is_symmorphic is True when a non-zero fractional translation is present, see above).
        if self.is_symmorphic and self.on_bz_border:
            app("WARNING: non-symmorphic little group with k at zone-border.")
            app("Electronic states cannot be classified with this character table.")

        return "\n".join(lines)

    #def iter_symmop_g0_byclass(self):
976
977
class LatticePointGroup(OpSequence):
    """
    Point group defined by a set of lattice rotations (integer matrices in reduced
    coordinates). The Hermann-Mauguin symbol is obtained from spglib.
    """

    def __init__(self, rotations):
        """
        Args:
            rotations: Array-like with the rotation matrices, reshaped to (nrot, 3, 3).

        Raises:
            ValueError: if the symbol returned by spglib has no Schoenflies equivalent
                in the internal table.
        """
        rotations = np.reshape(rotations, (-1, 3, 3))
        self._ops = [LatticeRotation(rot) for rot in rotations]

        # Call spglib to get the Herm symbol.
        # spglib returns (symbol, pointgroup_number, transformation_matrix);
        # only the symbol is used here.
        herm_symbol, _, _ = spglib.get_pointgroup(rotations)
        # Remove blanks from C string.
        self.herm_symbol = herm_symbol.strip()

        if self.sch_symbol is None:
            # Report the symbol we actually got: sch_symbol is always None at this point.
            raise ValueError("Cannot detect point group symbol! "
                             "Got herm_symbol = %r from spglib, which has no Schoenflies equivalent" % self.herm_symbol)

    #@classmethod
    #def from_vectors(cls, vectors)

    def __repr__(self):
        return "%s: %s, %s (%d)" % (self.__class__.__name__, self.herm_symbol, self.sch_symbol, self.spgid)

    def __str__(self):
        return "%s, %s (%d)" % (self.herm_symbol, self.sch_symbol, self.spgid)

    @property
    def sch_symbol(self):
        """Schoenflies symbol (None if herm_symbol is not in the internal table)."""
        return herm2sch(self.herm_symbol)

    @property
    def spgid(self):
        """ID in the space group table."""
        return sch2spgid(self.sch_symbol)
1012
1013
class LatticeRotation(Operation):
    """
    This object defines a pure rotation of the lattice (proper, improper, mirror symmetry)
    that is a rotation which is compatible with a lattice. The rotation matrix is
    expressed in reduced coordinates, therefore its elements are integers.

    See:
        http://xrayweb2.chem.ou.edu/notes/symmetry.html#rotation

    .. note::

        This object is immutable and therefore we do not inherit from |numpy-array|.
    """
    _E3D = np.identity(3,  int)

    def __init__(self, mat):
        """
        Args:
            mat: matrix-like object with 9 elements, stored as a (3, 3) integer array.
        """
        # np.array (instead of np.asarray) takes a private copy. With asarray an integer
        # input array would be aliased: reshaping it would change the caller's array in
        # place, and __pow__(0) would share the class-level _E3D matrix.
        self.mat = np.array(mat, dtype=int).reshape(3, 3)

    def _find_order_and_rootinv(self):
        """
        Returns the order of the rotation and if self is a root of the inverse.

        Returns:
            (order, root_inv): order is the smallest n in [1, 6] such that self**n is the
            identity; root_inv is the power n < order such that self**n is the inversion
            (0 if there is no such power).

        Raises:
            ValueError: if no power up to 6 gives the identity.
        """
        order, root_inv = None, 0
        for ior in range(1, 7):
            rn = self ** ior

            if rn.isE:
                order = ior
                break

            if rn.isI:
                root_inv = ior

        if order is None:
            raise ValueError("LatticeRotation is not a root of unit!")

        return order, root_inv

    def __repr__(self):
        return self.name

    #def __str__(self):
    #    lines = "Rotation: " + str(self.order) + ", versor: " + str(self.versor) + ",
    #    lines.append(str(self.mat))
    #    return "\n".join(lines)

    # operator protocol.
    def __eq__(self, other):
        # Let Python fall back to the default comparison for objects without a matrix
        # instead of raising AttributeError.
        other_mat = getattr(other, "mat", None)
        if other_mat is None:
            return NotImplemented
        return np.allclose(self.mat, other_mat)

    def __mul__(self, other):
        return self.__class__(np.matmul(self.mat, other.mat))

    def __hash__(self):
        # Equal matrices have equal trace and determinant, so this is consistent with __eq__.
        return int(8 * self.trace + 4 * self.det)

    def inverse(self):
        """
        Invert an orthogonal 3x3 matrix of INTEGER elements.
        Note use of integer arithmetic. Raise ValueError if not invertible.
        """
        return self.__class__(mati3inv(self.mat, trans=False))

    @lazy_property
    def isE(self):
        """True if it is the identity"""
        return np.allclose(self.mat, self._E3D)
    # end operator protocol.

    # Implement the unary arithmetic operations (+, -)
    def __pos__(self):
        return self

    def __neg__(self):
        return self.__class__(-self.mat)

    def __pow__(self, intexp, modulo=1):
        """
        Integer power of the rotation. Negative exponents use the inverse.
        ``modulo`` is accepted for compatibility with the pow() protocol but ignored.
        """
        if intexp == 0: return self.__class__(self._E3D)
        if intexp > 0: return self.__class__(np.linalg.matrix_power(self.mat, intexp))
        if intexp == -1: return self.inverse()
        if intexp < 0: return self.__pow__(-intexp).inverse()
        raise TypeError("type %s is not supported in __pow__" % type(intexp))

    @property
    def order(self):
        """Order of the rotation."""
        try:
            return self._order
        except AttributeError:
            self._order, self._root_inv = self._find_order_and_rootinv()
            return self._order

    @property
    def root_inv(self):
        """Power n such that self**n is the inversion (0 if self is not a root of the inversion)."""
        try:
            return self._root_inv
        except AttributeError:
            self._order, self._root_inv = self._find_order_and_rootinv()
            return self._root_inv

    @lazy_property
    def det(self):
        """Return the determinant of a symmetry matrix mat[3,3]. It must be +-1"""
        return _get_det(self.mat)

    @lazy_property
    def trace(self):
        """The trace of the rotation matrix"""
        return self.mat.trace()

    @lazy_property
    def is_proper(self):
        """True if proper rotation"""
        return self.det == 1

    @lazy_property
    def isI(self):
        """True if self is the inversion operation."""
        return np.allclose(self.mat, -self._E3D)

    @lazy_property
    def name(self):
        """Label built from the determinant sign, the order and the root-of-inversion flag."""
        # Sign of the determinant (only if improper)
        name = "-" if self.det == -1 else ""
        name += str(self.order)
        # Root of inverse?
        name += "-" if self.root_inv != 0 else "+"

        return name

    #@property
    #def rottype(self):
    #    """
    #    Receive a 3x3 orthogonal matrix and reports its type:
    #        1 Identity
    #        2 Inversion
    #        3 Proper rotation of an angle <> 180 degrees
    #        4 Proper rotation of 180 degrees
    #        5 Mirror symmetry
    #        6 Improper rotation
    #    """
    #    # Treat identity and inversion first
    #    if self.isE: return 1
    #    if self.isI: return 2
    #
    #    if self.isproper: # Proper rotation
    #        t = 3 # try angle != 180
    #        #det180 = get_sym_det(rot + self._E3D)
    #        if (self + identity).det == 0: t = 4 # 180 rotation
    #    else:
    #        # Mirror symmetry or Improper rotation
    #        t = 6
    #        #detmirror = get_sym_det(rot - self._E3D)
    #        if (self - identity).det == 0:
    #            t = 5 # Mirror symmetry if an eigenvalue is 1

    #    return t
1172
1173
1174# TODO: Need to find an easy way to map classes in internal database
1175# onto classes computed by client code when calculation has been done
1176# with non-conventional settings (spglib?)
class Irrep(object):
    """
    This object represents an irreducible representation.

    .. attributes::

        name: Name of the irreducible representation.
        traces: traces[nsym]. The trace of each representation matrix.
        character: character[num_classes]. The character of each class
            (trace of the first matrix in the class).
    """
    def __init__(self, name, dim, mats, class_range):
        """
        Args:
            name:  Name of the irreducible representation.
            dim: Dimension of the irreducible representation.
            mats: Array of shape [nsym,dim,dim] with the irreducible
                representations of the group. mats are packed in classes.
            class_range: List of tuples, each tuple gives the start and stop index for the class.
                e.g. [(0, 2), (2,4), (4,n)]
        """
        self.name = name
        self._mats = np.reshape(np.array(mats), (-1, dim, dim))

        self.traces = [m.trace() for m in self.mats]

        self.class_range = class_range
        self.nclass = len(class_range)

        # Compute character table.
        # Matrices in the same class are conjugate, so they must share the same trace.
        # The original code computed this check but ignored the result: report
        # inconsistent input instead of silently using the first trace.
        character = self.nclass * [None]
        for icls, (start, stop) in enumerate(self.class_range):
            t0 = self.traces[start]
            if not np.allclose(self.traces[start:stop], t0):
                warnings.warn("Irrep %s: traces in class %d (indices %d:%d) are not equal, "
                              "using the first one as character" % (name, icls, start, stop))
            character[icls] = t0

        self._character = character

    @property
    def mats(self):
        """Representation matrices, shape (nsym, dim, dim), packed in classes."""
        return self._mats

    @property
    def character(self):
        """List with the character of each class."""
        return self._character

    #@lazy_property
    #def dataframe(self):
1223
1224
def bilbao_ptgroup(sch_symbol):
    """
    Returns an instance of :class:`BilbaoPointGroup` from a string with the point group symbol
    (Schoenflies or Hermann-Mauguin) or a number with the spacegroup ID.

    Raises:
        KeyError: if the input cannot be converted to a Schoenflies symbol or the
            symbol is not present in the irreps database.
    """
    symbol = any2sch(sch_symbol)
    if symbol is None:
        # any2sch returns None for unknown input: raise a KeyError that names the
        # offending value instead of the opaque "KeyError: None" from the lookup below.
        raise KeyError("Cannot convert %r to a Schoenflies point group symbol" % (sch_symbol,))

    from abipy.core.irrepsdb import _PTG_IRREPS_DB
    # Shallow copy so that popping "nclass" does not modify the database entry.
    entry = _PTG_IRREPS_DB[symbol].copy()
    entry.pop("nclass")
    entry["sch_symbol"] = symbol

    return BilbaoPointGroup(**entry)
1238
1239
class BilbaoPointGroup(object):
    """
    A :class:`BilbaoPointGroup` is a :class:`Pointgroup` with irreducible representations
    """
    def __init__(self, sch_symbol, rotations, class_names, class_range, irreps):
        """
        Args:
            sch_symbol: Schoenflies symbol of the point group.
            rotations: Rotation matrices grouped in classes, reshaped to (num_rots, 3, 3).
            class_names: Names of the classes (not necessarily unique).
            class_range: List of (start, stop) tuples with the index range of each class
                in rotations e.g. [(0, 2), (2,4), (4,n)].
            irreps: Mapping irrep_name --> dict with the "dim" of the irrep and its
                "matrices" (one per rotation).
        """
        # Rotations are grouped in classes.
        self.sch_symbol = sch_symbol
        self.rotations = np.reshape(rotations, (-1, 3, 3))
        self.class_names = class_names
        self.nclass = len(class_names)

        # List of tuples, each tuple gives the start and stop index for the class.
        # e.g. [(0, 2), (2,4), (4,n)]
        self.class_range = class_range
        self.class_len = [stop - start for start, stop in class_range]

        # The number of irreps must equal the number of classes.
        assert len(irreps) == self.nclass
        self.irreps, self.irreps_by_name = [], {}
        for name, d in irreps.items():
            mats = d["matrices"]
            assert len(mats) == self.num_rots
            irrep = Irrep(name, d["dim"], mats, class_range=self.class_range)
            self.irreps.append(irrep)
            self.irreps_by_name[name] = irrep

    @property
    def herm_symbol(self):
        """Hermann-Mauguin symbol."""
        # sch_symbol is a Schoenflies symbol, so the Schoenflies --> Hermann-Mauguin
        # conversion is needed (herm2sch would always return None here).
        return sch2herm(self.sch_symbol)

    @property
    def spgid(self):
        """ID in the space group table."""
        return sch2spgid(self.sch_symbol)

    @property
    def num_rots(self):
        """Number of rotations."""
        return len(self.rotations)

    @property
    def num_irreps(self):
        """Number of irreducible representations."""
        return len(self.irreps)

    @property
    def irrep_names(self):
        """List with the names of the irreps."""
        return list(self.irreps_by_name.keys())

    @lazy_property
    def character_table(self):
        """
        Dataframe with the character table: one row per irrep, one column per class.
        Columns are labelled "class_name [class_size]".
        """
        # Caveat: class names are not necessarily unique --> build the table with np.stack
        # and append the class size to the column labels.
        import pandas as pd
        name_mult = [name + " [" + str(mult) + "]" for (name, mult) in zip(self.class_names, self.class_len)]

        stack = np.stack([irrep.character for irrep in self.irreps])
        index = [irrep.name for irrep in self.irreps]
        df = pd.DataFrame(stack, columns=name_mult, index=index)
        df.index.name = "Irrep"
        df.columns.name = self.sch_symbol

        # TODO
        # Convert complex --> real if all entries in a colums are real.
        #for k in name_mult:
        #    if np.all(np.isreal(df[k].values)):
        #        #df[k] = df[k].values.real
        #        df[k] = df[k].astype(float)

        return df

    def to_string(self, verbose=0):
        """
        Return string with the character_table (verbose is currently unused).
        """
        return self.character_table.to_string()

    #def decompose(self, character):
    #   od = collections.OrderedDict()
    #   for irrep in self.irreps:
    #       irrep.name
    #       irrep.character
    #   return od

    #def show_irrep(self, irrep_name):
    #    """Show the mapping rotation --> irrep mat."""
    #    irrep = self.irreps_by_name[irrep_name]

    #def irrep_from_character(self, character, rotations, tol=None):
    #    """
    #    Main entry point for client code.
    #    This routine receives a character computed from the user and finds the
    #    irreducible representation.
    #    """

    #def map_rotclasses(self, rotations_in_classes)
    #def map_rotation(self, rotations_in_classes)

    def auto_test(self):
        """
        Perform internal consistency check. Return 0 if success.

        Error codes:
            1: rotations do not form a group.
            2: rotations are not ordered in classes.
            3: irreps do not form a representation of the group.
            4: error in the orthogonality relation of traces.
        """
        rot_group = LatticePointGroup(self.rotations)
        if not rot_group.is_group():
            print("rotations do not form a group!")
            return 1

        # Symmetries should be ordered in classes.
        # Here we recompute the classes by calling rot_group.class_indices.
        # We then sort the indices and we compare the results with the ref data stored in the Bilbao database.
        calc_class_inds = [sorted(l) for l in rot_group.class_indices]
        assert len(calc_class_inds) == len(self.class_range)

        for calc_inds, ref_range in zip(calc_class_inds, self.class_range):
            ref_inds = list(range(ref_range[0], ref_range[1]))
            if calc_inds != ref_inds:
                print("Rotations are not ordered in classes.", calc_inds, ref_inds)
                return 2

        # Do we have a representation of the Group?
        # Check D(R1) D(R2) == D(R1 R2) for all pairs, using the multiplication table.
        mult_table = rot_group.mult_table
        max_err = 0.0

        for idx1, rot1 in enumerate(rot_group):
            for idx2, rot2 in enumerate(rot_group):
                idx_prod = mult_table[idx1, idx2]
                for irrep in self.irreps:
                    mat_prod = np.dot(irrep.mats[idx1], irrep.mats[idx2])
                    # Take the modulus before the max: max() of signed (or complex)
                    # differences would miss large negative deviations.
                    err = np.abs(mat_prod - irrep.mats[idx_prod]).max()
                    max_err = max(max_err, err)

        if max_err > 1e-5:
            print("Irreps do not form a representation of the group, max_err: ", max_err)
            return 3

        # TODO
        # Test orthogonality theorem

        # Test the orthogonality relation of traces.
        max_err = 0.0
        for (ii, jj), (irp1, irp2) in iuptri(self.irreps, with_inds=True):
            trac1, trac2 = irp1.traces, irp2.traces
            # vdot conjugates the first argument.
            err = np.vdot(trac1, trac2) / self.num_rots
            if ii == jj: err -= 1.0
            max_err = max(max_err, abs(err))

        if max_err > 1e-5:
            print("Error in orthogonality relation of traces: ", max_err)
            return 4

        # Success.
        return 0
1400
1401
# Point-group lookup table. Each entry is the triple:
#     (Schoenflies symbol, Hermann-Mauguin symbol, space group id)
# NOTE(review): every entry uses the lowest-numbered space group of its point group,
# except D3h which is mapped to 189 (P-62m) instead of 187 (P-6m2) -- confirm intended.
_PTG_IDS = [
    ("C1",  "1",      1),
    ("Ci",  "-1",     2),
    ("C2",  "2",      3),
    ("Cs",  "m",      6),
    ("C2h", "2/m",    10),
    ("D2",  "222",    16),
    ("C2v", "mm2",    25),
    ("D2h", "mmm",    47),
    ("C4",  "4",      75),
    ("S4",  "-4",     81),
    ("C4h", "4/m",    83),
    ("D4",  "422",    89),
    ("C4v", "4mm",    99),
    ("D2d", "-42m",   111),
    ("D4h", "4/mmm",  123),
    ("C3",  "3",      143),
    ("C3i", "-3",     147),
    ("D3",  "32",     149),
    ("C3v", "3m",     156),
    ("D3d", "-3m",    162),
    ("C6",  "6",      168),
    ("C3h", "-6",     174),
    ("C6h", "6/m",    175),
    ("D6",  "622",    177),
    ("C6v", "6mm",    183),
    ("D3h", "-6m2",   189),
    ("D6h", "6/mmm",  191),
    ("T",   "23",     195),
    ("Th",  "m-3",    200),
    ("O",   "432",    207),
    ("Td",  "-43m",   215),
    ("Oh",  "m-3m",   221),
]

# Lookup dictionaries derived from the table above (table order is preserved).
_SCH2HERM = {sch: herm for sch, herm, _ in _PTG_IDS}
_HERM2SCH = {herm: sch for sch, herm, _ in _PTG_IDS}
_SPGID2SCH = {spgid: sch for sch, _, spgid in _PTG_IDS}
_SCH2SPGID = {sch: spgid for sch, _, spgid in _PTG_IDS}

# All the Schoenflies symbols, in table order.
sch_symbols = list(_SCH2HERM)
1444
1445
def sch2herm(sch_symbol):
    """
    Map a Schoenflies point-group symbol to its Hermann-Mauguin counterpart.
    Returns None if the symbol is not in the table.
    """
    try:
        return _SCH2HERM[sch_symbol]
    except KeyError:
        return None
1449
1450
def sch2spgid(sch_symbol):
    """
    Map a Schoenflies point-group symbol to the associated space group id.
    Returns None if the symbol is not in the table.
    """
    spgid = _SCH2SPGID.get(sch_symbol)
    return spgid
1454
1455
def herm2sch(herm_symbol):
    """
    Map a Hermann-Mauguin point-group symbol to its Schoenflies counterpart.
    Returns None if the symbol is not in the table.
    """
    try:
        return _HERM2SCH[herm_symbol]
    except KeyError:
        return None
1459
1460
def spgid2sch(spgid):
    """
    Map a space group identifier to the Schoenflies symbol of its point group.
    Returns None if the id is not in the table.
    """
    sch_symbol = _SPGID2SCH.get(spgid)
    return sch_symbol
1464
1465
def any2sch(obj):
    """
    Convert a Schoenflies symbol, a Hermann-Mauguin symbol or a space group id
    to the Schoenflies symbol. Returns None if the input cannot be converted.
    """
    if not is_string(obj):
        # Non-string input is interpreted as a space group identifier.
        return spgid2sch(obj)

    # Strings: accept Schoenflies symbols as they are, otherwise try Hermann-Mauguin.
    return obj if obj in sch_symbols else herm2sch(obj)
1477