1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
6An interface to the excellent spglib library by Atsushi Togo
7(http://spglib.sourceforge.net/) for pymatgen.
8
9v1.0 - Now works with both ordered and disordered structure.
10v2.0 - Updated for spglib 1.6.
11v3.0 - pymatgen no longer ships with spglib. Instead, spglib (the python
12       version) is now a dependency and the SpacegroupAnalyzer merely serves
13       as an interface to spglib for pymatgen Structures.
14"""
15
16import copy
17import itertools
18import logging
19import math
20from collections import defaultdict
21from fractions import Fraction
22from math import cos, sin
23
24import numpy as np
25import spglib
26
27from pymatgen.core.lattice import Lattice
28from pymatgen.core.operations import SymmOp
29from pymatgen.core.structure import Molecule, PeriodicSite, Structure
30from pymatgen.symmetry.structure import SymmetrizedStructure
31from pymatgen.util.coord import find_in_coord_list, pbc_diff
32
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
34
35
36class SpacegroupAnalyzer:
37    """
38    Takes a pymatgen.core.structure.Structure object and a symprec.
39    Uses spglib to perform various symmetry finding operations.
40    """
41
42    def __init__(self, structure, symprec=0.01, angle_tolerance=5.0):
43        """
44        Args:
45            structure (Structure/IStructure): Structure to find symmetry
46            symprec (float): Tolerance for symmetry finding. Defaults to 0.01,
47                which is fairly strict and works well for properly refined
48                structures with atoms in the proper symmetry coordinates. For
49                structures with slight deviations from their proper atomic
50                positions (e.g., structures relaxed with electronic structure
51                codes), a looser tolerance of 0.1 (the value used in Materials
52                Project) is often needed.
53            angle_tolerance (float): Angle tolerance for symmetry finding.
54        """
55        self._symprec = symprec
56        self._angle_tol = angle_tolerance
57        self._structure = structure
58        latt = structure.lattice.matrix
59        positions = structure.frac_coords
60        unique_species = []
61        zs = []
62        magmoms = []
63
64        for species, g in itertools.groupby(structure, key=lambda s: s.species):
65            if species in unique_species:
66                ind = unique_species.index(species)
67                zs.extend([ind + 1] * len(tuple(g)))
68            else:
69                unique_species.append(species)
70                zs.extend([len(unique_species)] * len(tuple(g)))
71
72        for site in structure:
73            if hasattr(site, "magmom"):
74                magmoms.append(site.magmom)
75            elif site.is_ordered and hasattr(site.specie, "spin"):
76                magmoms.append(site.specie.spin)
77            else:
78                magmoms.append(0)
79
80        self._unique_species = unique_species
81        self._numbers = zs
82        # For now, we are setting magmom to zero.
83        self._cell = latt, positions, zs, magmoms
84
85        self._space_group_data = spglib.get_symmetry_dataset(
86            self._cell, symprec=self._symprec, angle_tolerance=angle_tolerance
87        )
88
89    def get_space_group_symbol(self):
90        """
91        Get the spacegroup symbol (e.g., Pnma) for structure.
92
93        Returns:
94            (str): Spacegroup symbol for structure.
95        """
96        return self._space_group_data["international"]
97
98    def get_space_group_number(self):
99        """
100        Get the international spacegroup number (e.g., 62) for structure.
101
102        Returns:
103            (int): International spacegroup number for structure.
104        """
105        return int(self._space_group_data["number"])
106
107    def get_space_group_operations(self):
108        """
109        Get the SpacegroupOperations for the Structure.
110
111        Returns:
112            SpacgroupOperations object.
113        """
114        return SpacegroupOperations(
115            self.get_space_group_symbol(),
116            self.get_space_group_number(),
117            self.get_symmetry_operations(),
118        )
119
120    def get_hall(self):
121        """
122        Returns Hall symbol for structure.
123
124        Returns:
125            (str): Hall symbol
126        """
127        return self._space_group_data["hall"]
128
129    def get_point_group_symbol(self):
130        """
131        Get the point group associated with the structure.
132
133        Returns:
134            (Pointgroup): Point group for structure.
135        """
136        rotations = self._space_group_data["rotations"]
137        # passing a 0-length rotations list to spglib can segfault
138        if len(rotations) == 0:
139            return "1"
140        return spglib.get_pointgroup(rotations)[0].strip()
141
142    def get_crystal_system(self):
143        """
144        Get the crystal system for the structure, e.g., (triclinic,
145        orthorhombic, cubic, etc.).
146
147        Returns:
148            (str): Crystal system for structure or None if system cannot be detected.
149        """
150        n = self._space_group_data["number"]
151
152        if 0 < n < 3:
153            return "triclinic"
154        if n < 16:
155            return "monoclinic"
156        if n < 75:
157            return "orthorhombic"
158        if n < 143:
159            return "tetragonal"
160        if n < 168:
161            return "trigonal"
162        if n < 195:
163            return "hexagonal"
164        if n < 231:
165            return "cubic"
166
167        raise ValueError("Invalid space group")
168
169    def get_lattice_type(self):
170        """
171        Get the lattice for the structure, e.g., (triclinic,
172        orthorhombic, cubic, etc.).This is the same than the
173        crystal system with the exception of the hexagonal/rhombohedral
174        lattice
175
176        Returns:
177            (str): Lattice type for structure or None if type cannot be detected.
178        """
179        n = self._space_group_data["number"]
180        system = self.get_crystal_system()
181        if n in [146, 148, 155, 160, 161, 166, 167]:
182            return "rhombohedral"
183        if system == "trigonal":
184            return "hexagonal"
185        return system
186
187    def get_symmetry_dataset(self):
188        """
189        Returns the symmetry dataset as a dict.
190
191        Returns:
192            (dict): With the following properties:
193            number: International space group number
194            international: International symbol
195            hall: Hall symbol
196            transformation_matrix: Transformation matrix from lattice of
197            input cell to Bravais lattice L^bravais = L^original * Tmat
198            origin shift: Origin shift in the setting of "Bravais lattice"
199            rotations, translations: Rotation matrices and translation
200            vectors. Space group operations are obtained by
201            [(r,t) for r, t in zip(rotations, translations)]
202            wyckoffs: Wyckoff letters
203        """
204        return self._space_group_data
205
206    def _get_symmetry(self):
207        """
208        Get the symmetry operations associated with the structure.
209
210        Returns:
211            Symmetry operations as a tuple of two equal length sequences.
212            (rotations, translations). "rotations" is the numpy integer array
213            of the rotation matrices for scaled positions
214            "translations" gives the numpy float64 array of the translation
215            vectors in scaled positions.
216        """
217        d = spglib.get_symmetry(self._cell, symprec=self._symprec, angle_tolerance=self._angle_tol)
218        # Sometimes spglib returns small translation vectors, e.g.
219        # [1e-4, 2e-4, 1e-4]
220        # (these are in fractional coordinates, so should be small denominator
221        # fractions)
222        trans = []
223        for t in d["translations"]:
224            trans.append([float(Fraction.from_float(c).limit_denominator(1000)) for c in t])
225        trans = np.array(trans)
226
227        # fractional translations of 1 are more simply 0
228        trans[np.abs(trans) == 1] = 0
229        return d["rotations"], trans
230
231    def get_symmetry_operations(self, cartesian=False):
232        """
233        Return symmetry operations as a list of SymmOp objects.
234        By default returns fractional coord symmops.
235        But cartesian can be returned too.
236
237        Returns:
238            ([SymmOp]): List of symmetry operations.
239        """
240        rotation, translation = self._get_symmetry()
241        symmops = []
242        mat = self._structure.lattice.matrix.T
243        invmat = np.linalg.inv(mat)
244        for rot, trans in zip(rotation, translation):
245            if cartesian:
246                rot = np.dot(mat, np.dot(rot, invmat))
247                trans = np.dot(trans, self._structure.lattice.matrix)
248            op = SymmOp.from_rotation_and_translation(rot, trans)
249            symmops.append(op)
250        return symmops
251
252    def get_point_group_operations(self, cartesian=False):
253        """
254        Return symmetry operations as a list of SymmOp objects.
255        By default returns fractional coord symmops.
256        But cartesian can be returned too.
257
258        Args:
259            cartesian (bool): Whether to return SymmOps as cartesian or
260                direct coordinate operations.
261
262        Returns:
263            ([SymmOp]): List of point group symmetry operations.
264        """
265        rotation, translation = self._get_symmetry()
266        symmops = []
267        mat = self._structure.lattice.matrix.T
268        invmat = np.linalg.inv(mat)
269        for rot in rotation:
270            if cartesian:
271                rot = np.dot(mat, np.dot(rot, invmat))
272            op = SymmOp.from_rotation_and_translation(rot, np.array([0, 0, 0]))
273            symmops.append(op)
274        return symmops
275
276    def get_symmetrized_structure(self):
277        """
278        Get a symmetrized structure. A symmetrized structure is one where the
279        sites have been grouped into symmetrically equivalent groups.
280
281        Returns:
282            :class:`pymatgen.symmetry.structure.SymmetrizedStructure` object.
283        """
284        ds = self.get_symmetry_dataset()
285        sg = SpacegroupOperations(
286            self.get_space_group_symbol(),
287            self.get_space_group_number(),
288            self.get_symmetry_operations(),
289        )
290        return SymmetrizedStructure(self._structure, sg, ds["equivalent_atoms"], ds["wyckoffs"])
291
292    def get_refined_structure(self):
293        """
294        Get the refined structure based on detected symmetry. The refined
295        structure is a *conventional* cell setting with atoms moved to the
296        expected symmetry positions.
297
298        Returns:
299            Refined structure.
300        """
301        # Atomic positions have to be specified by scaled positions for spglib.
302        lattice, scaled_positions, numbers = spglib.refine_cell(self._cell, self._symprec, self._angle_tol)
303
304        species = [self._unique_species[i - 1] for i in numbers]
305        s = Structure(lattice, species, scaled_positions)
306        return s.get_sorted_structure()
307
308    def find_primitive(self):
309        """
310        Find a primitive version of the unit cell.
311
312        Returns:
313            A primitive cell in the input cell is searched and returned
314            as an Structure object. If no primitive cell is found, None is
315            returned.
316        """
317        lattice, scaled_positions, numbers = spglib.find_primitive(self._cell, symprec=self._symprec)
318
319        species = [self._unique_species[i - 1] for i in numbers]
320
321        return Structure(lattice, species, scaled_positions, to_unit_cell=True).get_reduced_structure()
322
323    def get_ir_reciprocal_mesh(self, mesh=(10, 10, 10), is_shift=(0, 0, 0)):
324        """
325        k-point mesh of the Brillouin zone generated taken into account
326        symmetry.The method returns the irreducible kpoints of the mesh
327        and their weights
328
329        Args:
330            mesh (3x1 array): The number of kpoint for the mesh needed in
331                each direction
332            is_shift (3x1 array): Whether to shift the kpoint grid. (1, 1,
333            1) means all points are shifted by 0.5, 0.5, 0.5.
334
335        Returns:
336            A list of irreducible kpoints and their weights as a list of
337            tuples [(ir_kpoint, weight)], with ir_kpoint given
338            in fractional coordinates
339        """
340        shift = np.array([1 if i else 0 for i in is_shift])
341        mapping, grid = spglib.get_ir_reciprocal_mesh(np.array(mesh), self._cell, is_shift=shift, symprec=self._symprec)
342
343        results = []
344        for i, count in zip(*np.unique(mapping, return_counts=True)):
345            results.append(((grid[i] + shift * (0.5, 0.5, 0.5)) / mesh, count))
346        return results
347
348    def get_conventional_to_primitive_transformation_matrix(self, international_monoclinic=True):
349        """
350        Gives the transformation matrix to transform a conventional
351        unit cell to a primitive cell according to certain standards
352        the standards are defined in Setyawan, W., & Curtarolo, S. (2010).
353        High-throughput electronic band structure calculations:
354        Challenges and tools. Computational Materials Science,
355        49(2), 299-312. doi:10.1016/j.commatsci.2010.05.010
356
357        Returns:
358            Transformation matrix to go from conventional to primitive cell
359        """
360        conv = self.get_conventional_standard_structure(international_monoclinic=international_monoclinic)
361        lattice = self.get_lattice_type()
362
363        if "P" in self.get_space_group_symbol() or lattice == "hexagonal":
364            return np.eye(3)
365
366        if lattice == "rhombohedral":
367            # check if the conventional representation is hexagonal or
368            # rhombohedral
369            lengths = conv.lattice.lengths
370            if abs(lengths[0] - lengths[2]) < 0.0001:
371                transf = np.eye
372            else:
373                transf = np.array([[-1, 1, 1], [2, 1, 1], [-1, -2, 1]], dtype=np.float_) / 3
374
375        elif "I" in self.get_space_group_symbol():
376            transf = np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]], dtype=np.float_) / 2
377        elif "F" in self.get_space_group_symbol():
378            transf = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=np.float_) / 2
379        elif "C" in self.get_space_group_symbol() or "A" in self.get_space_group_symbol():
380            if self.get_crystal_system() == "monoclinic":
381                transf = np.array([[1, 1, 0], [-1, 1, 0], [0, 0, 2]], dtype=np.float_) / 2
382            else:
383                transf = np.array([[1, -1, 0], [1, 1, 0], [0, 0, 2]], dtype=np.float_) / 2
384        else:
385            transf = np.eye(3)
386
387        return transf
388
389    def get_primitive_standard_structure(self, international_monoclinic=True):
390        """
391        Gives a structure with a primitive cell according to certain standards
392        the standards are defined in Setyawan, W., & Curtarolo, S. (2010).
393        High-throughput electronic band structure calculations:
394        Challenges and tools. Computational Materials Science,
395        49(2), 299-312. doi:10.1016/j.commatsci.2010.05.010
396
397        Returns:
398            The structure in a primitive standardized cell
399        """
400        conv = self.get_conventional_standard_structure(international_monoclinic=international_monoclinic)
401        lattice = self.get_lattice_type()
402
403        if "P" in self.get_space_group_symbol() or lattice == "hexagonal":
404            return conv
405
406        transf = self.get_conventional_to_primitive_transformation_matrix(
407            international_monoclinic=international_monoclinic
408        )
409
410        new_sites = []
411        latt = Lattice(np.dot(transf, conv.lattice.matrix))
412        for s in conv:
413            new_s = PeriodicSite(
414                s.specie,
415                s.coords,
416                latt,
417                to_unit_cell=True,
418                coords_are_cartesian=True,
419                properties=s.properties,
420            )
421            if not any(map(new_s.is_periodic_image, new_sites)):
422                new_sites.append(new_s)
423
424        if lattice == "rhombohedral":
425            prim = Structure.from_sites(new_sites)
426            lengths = prim.lattice.lengths
427            angles = prim.lattice.angles
428            a = lengths[0]
429            alpha = math.pi * angles[0] / 180
430            new_matrix = [
431                [a * cos(alpha / 2), -a * sin(alpha / 2), 0],
432                [a * cos(alpha / 2), a * sin(alpha / 2), 0],
433                [
434                    a * cos(alpha) / cos(alpha / 2),
435                    0,
436                    a * math.sqrt(1 - (cos(alpha) ** 2 / (cos(alpha / 2) ** 2))),
437                ],
438            ]
439            new_sites = []
440            latt = Lattice(new_matrix)
441            for s in prim:
442                new_s = PeriodicSite(
443                    s.specie,
444                    s.frac_coords,
445                    latt,
446                    to_unit_cell=True,
447                    properties=s.properties,
448                )
449                if not any(map(new_s.is_periodic_image, new_sites)):
450                    new_sites.append(new_s)
451            return Structure.from_sites(new_sites)
452
453        return Structure.from_sites(new_sites)
454
    def get_conventional_standard_structure(self, international_monoclinic=True):
        """
        Gives a structure with a conventional cell according to certain
        standards. The standards are defined in Setyawan, W., & Curtarolo,
        S. (2010). High-throughput electronic band structure calculations:
        Challenges and tools. Computational Materials Science,
        49(2), 299-312. doi:10.1016/j.commatsci.2010.05.010
        They basically enforce as much as possible
        norm(a1)<norm(a2)<norm(a3). NB This is not necessarily the same as the
        standard settings within the International Tables of Crystallography,
        for which get_refined_structure should be used instead.

        The refined structure (get_refined_structure) is the starting point.
        For each lattice type a 3x3 matrix "transf" (an axis permutation
        and/or sign change) is chosen together with a new lattice "latt".
        transf is then applied to every fractional coordinate of the refined
        structure, and the result is placed on latt.

        Args:
            international_monoclinic (bool): Only used for monoclinic
                lattices. If True, axes are permuted so that beta (rather than
                alpha) is the non-right angle; if the resulting beta is below
                90 degrees, a and b are additionally negated.

        Returns:
            The structure in a conventional standardized cell, sorted via
            get_sorted_structure.
        """
        tol = 1e-5
        struct = self.get_refined_structure()
        latt = struct.lattice
        latt_type = self.get_lattice_type()
        sorted_lengths = sorted(latt.abc)
        # Lattice vectors of the refined cell ordered from shortest to longest,
        # remembering their original axis index.
        sorted_dic = sorted(
            [{"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [0, 1, 2]],
            key=lambda k: k["length"],
        )

        if latt_type in ("orthorhombic", "cubic"):
            # you want to keep the c axis where it is
            # to keep the C- settings
            transf = np.zeros(shape=(3, 3))
            if self.get_space_group_symbol().startswith("C"):
                # C-centred: keep c fixed and order only a and b by length.
                transf[2] = [0, 0, 1]
                a, b = sorted(latt.abc[:2])
                sorted_dic = sorted(
                    [{"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [0, 1]],
                    key=lambda k: k["length"],
                )
                for i in range(2):
                    transf[i][sorted_dic[i]["orig_index"]] = 1
                c = latt.abc[2]
            elif self.get_space_group_symbol().startswith(
                "A"
            ):  # change to C-centering to match Setyawan/Curtarolo convention
                # The original a axis becomes the new c; the other two are
                # ordered by length.
                transf[2] = [1, 0, 0]
                a, b = sorted(latt.abc[1:])
                sorted_dic = sorted(
                    [{"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [1, 2]],
                    key=lambda k: k["length"],
                )
                for i in range(2):
                    transf[i][sorted_dic[i]["orig_index"]] = 1
                c = latt.abc[0]
            else:
                # No centring constraint: order all three axes by length.
                for i, d in enumerate(sorted_dic):
                    transf[i][d["orig_index"]] = 1
                a, b, c = sorted_lengths
            latt = Lattice.orthorhombic(a, b, c)

        elif latt_type == "tetragonal":
            # find the "a" vectors
            # it is basically the vector repeated two times
            transf = np.zeros(shape=(3, 3))
            a, b, c = sorted_lengths
            for i, d in enumerate(sorted_dic):
                transf[i][d["orig_index"]] = 1

            # If the two longest axes are equal and the shortest differs, the
            # unique axis is the shortest one: swap it into the c position.
            if abs(b - c) < tol < abs(a - c):
                a, c = c, a
                transf = np.dot([[0, 0, 1], [0, 1, 0], [1, 0, 0]], transf)
            latt = Lattice.tetragonal(a, c)
        elif latt_type in ("hexagonal", "rhombohedral"):
            # for the conventional cell representation,
            # we always show the rhombohedral lattices as hexagonal

            # check first if we have the refined structure shows a rhombohedral
            # cell
            # if so, make a supercell
            a, b, c = latt.abc
            # All three lengths equal => rhombohedral setting; the supercell
            # below converts it to the hexagonal setting (struct is modified
            # in place, which is why transf stays the identity).
            if np.all(np.abs([a - b, c - b, a - c]) < 0.001):
                struct.make_supercell(((1, -1, 0), (0, 1, -1), (1, 1, 1)))
                a, b, c = sorted(struct.lattice.abc)

            # If b == c, the odd one out is a: make it the c axis.
            if abs(b - c) < 0.001:
                a, c = c, a
            new_matrix = [
                [a / 2, -a * math.sqrt(3) / 2, 0],
                [a / 2, a * math.sqrt(3) / 2, 0],
                [0, 0, c],
            ]
            latt = Lattice(new_matrix)
            transf = np.eye(3, 3)

        elif latt_type == "monoclinic":
            # You want to keep the c axis where it is to keep the C- settings

            # NOTE(review): this builds all symmetry operations just to read
            # the symbol; get_space_group_symbol() would appear to suffice.
            if self.get_space_group_operations().int_symbol.startswith("C"):
                transf = np.zeros(shape=(3, 3))
                transf[2] = [0, 0, 1]
                sorted_dic = sorted(
                    [{"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [0, 1]],
                    key=lambda k: k["length"],
                )
                a = sorted_dic[0]["length"]
                b = sorted_dic[1]["length"]
                c = latt.abc[2]
                new_matrix = None
                # Try both orderings of the first two axes; the last ordering
                # with a non-90 alpha determines transf and new_matrix.
                for t in itertools.permutations(list(range(2)), 2):
                    m = latt.matrix
                    latt2 = Lattice([m[t[0]], m[t[1]], m[2]])
                    lengths = latt2.lengths
                    angles = latt2.angles
                    if angles[0] > 90:
                        # if the angle is > 90 we invert a and b to get
                        # an angle < 90
                        a, b, c, alpha, beta, gamma = Lattice([-m[t[0]], -m[t[1]], m[2]]).parameters
                        transf = np.zeros(shape=(3, 3))
                        transf[0][t[0]] = -1
                        transf[1][t[1]] = -1
                        transf[2][2] = 1
                        alpha = math.pi * alpha / 180
                        new_matrix = [
                            [a, 0, 0],
                            [0, b, 0],
                            [0, c * cos(alpha), c * sin(alpha)],
                        ]
                        continue

                    if angles[0] < 90:
                        transf = np.zeros(shape=(3, 3))
                        transf[0][t[0]] = 1
                        transf[1][t[1]] = 1
                        transf[2][2] = 1
                        a, b, c = lengths
                        alpha = math.pi * angles[0] / 180
                        new_matrix = [
                            [a, 0, 0],
                            [0, b, 0],
                            [0, c * cos(alpha), c * sin(alpha)],
                        ]

                if new_matrix is None:
                    # this if is to treat the case
                    # where alpha==90 (but we still have a monoclinic sg
                    new_matrix = [[a, 0, 0], [0, b, 0], [0, 0, c]]
                    transf = np.zeros(shape=(3, 3))
                    transf[2] = [0, 0, 1]  # see issue #1929
                    for i, d in enumerate(sorted_dic):
                        transf[i][d["orig_index"]] = 1
            # if not C-setting
            else:
                # try all permutations of the axis
                # keep the ones with the non-90 angle=alpha
                # and b<c
                # (the last qualifying permutation wins)
                new_matrix = None
                for t in itertools.permutations(list(range(3)), 3):
                    m = latt.matrix
                    a, b, c, alpha, beta, gamma = Lattice([m[t[0]], m[t[1]], m[t[2]]]).parameters
                    if alpha > 90 and b < c:
                        # Negate the first two axes to turn the obtuse alpha
                        # into an acute one.
                        a, b, c, alpha, beta, gamma = Lattice([-m[t[0]], -m[t[1]], m[t[2]]]).parameters
                        transf = np.zeros(shape=(3, 3))
                        transf[0][t[0]] = -1
                        transf[1][t[1]] = -1
                        transf[2][t[2]] = 1
                        alpha = math.pi * alpha / 180
                        new_matrix = [
                            [a, 0, 0],
                            [0, b, 0],
                            [0, c * cos(alpha), c * sin(alpha)],
                        ]
                        continue
                    if alpha < 90 and b < c:
                        transf = np.zeros(shape=(3, 3))
                        transf[0][t[0]] = 1
                        transf[1][t[1]] = 1
                        transf[2][t[2]] = 1
                        alpha = math.pi * alpha / 180
                        new_matrix = [
                            [a, 0, 0],
                            [0, b, 0],
                            [0, c * cos(alpha), c * sin(alpha)],
                        ]
                if new_matrix is None:
                    # this if is to treat the case
                    # where alpha==90 (but we still have a monoclinic sg
                    new_matrix = [
                        [sorted_lengths[0], 0, 0],
                        [0, sorted_lengths[1], 0],
                        [0, 0, sorted_lengths[2]],
                    ]
                    transf = np.zeros(shape=(3, 3))
                    for i, d in enumerate(sorted_dic):
                        transf[i][d["orig_index"]] = 1

            if international_monoclinic:
                # The above code makes alpha the non-right angle.
                # The following will convert to proper international convention
                # that beta is the non-right angle.
                op = [[0, 1, 0], [1, 0, 0], [0, 0, -1]]
                transf = np.dot(op, transf)
                new_matrix = np.dot(op, new_matrix)
                beta = Lattice(new_matrix).beta
                if beta < 90:
                    # Negating a and b turns beta into 180 - beta.
                    op = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]
                    transf = np.dot(op, transf)
                    new_matrix = np.dot(op, new_matrix)

            latt = Lattice(new_matrix)

        elif latt_type == "triclinic":
            # we use a LLL Minkowski-like reduction for the triclinic cells
            struct = struct.get_reduced_structure("LLL")
            latt = struct.lattice

            a, b, c = latt.lengths
            alpha, beta, gamma = [math.pi * i / 180 for i in latt.angles]
            new_matrix = None
            # Four candidate orientations (sign flips of the axes) are tried
            # in turn. Each one whose reciprocal angles are all <= 90 or all
            # > 90 overwrites transf/new_matrix, so the last qualifying
            # candidate wins.
            # NOTE(review): if no candidate qualifies, new_matrix stays None
            # and Lattice(None) below will fail; confirm this cannot happen.
            test_matrix = [
                [a, 0, 0],
                [b * cos(gamma), b * sin(gamma), 0.0],
                [
                    c * cos(beta),
                    c * (cos(alpha) - cos(beta) * cos(gamma)) / sin(gamma),
                    c
                    * math.sqrt(
                        sin(gamma) ** 2 - cos(alpha) ** 2 - cos(beta) ** 2 + 2 * cos(alpha) * cos(beta) * cos(gamma)
                    )
                    / sin(gamma),
                ],
            ]

            def is_all_acute_or_obtuse(m):
                # True if the reciprocal-lattice angles are all <= 90 degrees
                # or all > 90 degrees.
                recp_angles = np.array(Lattice(m).reciprocal_lattice.angles)
                return np.all(recp_angles <= 90) or np.all(recp_angles > 90)

            if is_all_acute_or_obtuse(test_matrix):
                transf = np.eye(3)
                new_matrix = test_matrix

            test_matrix = [
                [-a, 0, 0],
                [b * cos(gamma), b * sin(gamma), 0.0],
                [
                    -c * cos(beta),
                    -c * (cos(alpha) - cos(beta) * cos(gamma)) / sin(gamma),
                    -c
                    * math.sqrt(
                        sin(gamma) ** 2 - cos(alpha) ** 2 - cos(beta) ** 2 + 2 * cos(alpha) * cos(beta) * cos(gamma)
                    )
                    / sin(gamma),
                ],
            ]

            if is_all_acute_or_obtuse(test_matrix):
                transf = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]]
                new_matrix = test_matrix

            test_matrix = [
                [-a, 0, 0],
                [-b * cos(gamma), -b * sin(gamma), 0.0],
                [
                    c * cos(beta),
                    c * (cos(alpha) - cos(beta) * cos(gamma)) / sin(gamma),
                    c
                    * math.sqrt(
                        sin(gamma) ** 2 - cos(alpha) ** 2 - cos(beta) ** 2 + 2 * cos(alpha) * cos(beta) * cos(gamma)
                    )
                    / sin(gamma),
                ],
            ]

            if is_all_acute_or_obtuse(test_matrix):
                transf = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]
                new_matrix = test_matrix

            test_matrix = [
                [a, 0, 0],
                [-b * cos(gamma), -b * sin(gamma), 0.0],
                [
                    -c * cos(beta),
                    -c * (cos(alpha) - cos(beta) * cos(gamma)) / sin(gamma),
                    -c
                    * math.sqrt(
                        sin(gamma) ** 2 - cos(alpha) ** 2 - cos(beta) ** 2 + 2 * cos(alpha) * cos(beta) * cos(gamma)
                    )
                    / sin(gamma),
                ],
            ]
            if is_all_acute_or_obtuse(test_matrix):
                transf = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
                new_matrix = test_matrix

            latt = Lattice(new_matrix)

        # Apply transf to every fractional coordinate (rows of frac_coords).
        new_coords = np.dot(transf, np.transpose(struct.frac_coords)).T
        new_struct = Structure(
            latt,
            struct.species_and_occu,
            new_coords,
            site_properties=struct.site_properties,
            to_unit_cell=True,
        )
        return new_struct.get_sorted_structure()
756
757    def get_kpoint_weights(self, kpoints, atol=1e-5):
758        """
759        Calculate the weights for a list of kpoints.
760
761        Args:
762            kpoints (Sequence): Sequence of kpoints. np.arrays is fine. Note
763                that the code does not check that the list of kpoints
764                provided does not contain duplicates.
765            atol (float): Tolerance for fractional coordinates comparisons.
766
767        Returns:
768            List of weights, in the SAME order as kpoints.
769        """
770        kpts = np.array(kpoints)
771        shift = []
772        mesh = []
773        for i in range(3):
774            nonzero = [i for i in kpts[:, i] if abs(i) > 1e-5]
775            if len(nonzero) != len(kpts):
776                # gamma centered
777                if not nonzero:
778                    mesh.append(1)
779                else:
780                    m = np.abs(np.round(1 / np.array(nonzero)))
781                    mesh.append(int(max(m)))
782                shift.append(0)
783            else:
784                # Monk
785                m = np.abs(np.round(0.5 / np.array(nonzero)))
786                mesh.append(int(max(m)))
787                shift.append(1)
788
789        mapping, grid = spglib.get_ir_reciprocal_mesh(np.array(mesh), self._cell, is_shift=shift, symprec=self._symprec)
790        mapping = list(mapping)
791        grid = (np.array(grid) + np.array(shift) * (0.5, 0.5, 0.5)) / mesh
792        weights = []
793        mapped = defaultdict(int)
794        for k in kpoints:
795            for i, g in enumerate(grid):
796                if np.allclose(pbc_diff(k, g), (0, 0, 0), atol=atol):
797                    mapped[tuple(g)] += 1
798                    weights.append(mapping.count(mapping[i]))
799                    break
800        if (len(mapped) != len(set(mapping))) or (not all(v == 1 for v in mapped.values())):
801            raise ValueError("Unable to find 1:1 corresponding between input " "kpoints and irreducible grid!")
802        return [w / sum(weights) for w in weights]
803
804    def is_laue(self):
805        """
806        Check if the point group of the structure
807            has Laue symmetry (centrosymmetry)
808        """
809
810        laue = [
811            "-1",
812            "2/m",
813            "mmm",
814            "4/m",
815            "4/mmm",
816            "-3",
817            "-3m",
818            "6/m",
819            "6/mmm",
820            "m-3",
821            "m-3m",
822        ]
823
824        return str(self.get_point_group_symbol()) in laue
825
826
class PointGroupAnalyzer:
    """
    A class to analyze the point group of a molecule. The general outline of
    the algorithm is as follows:

    1. Center the molecule around its center of mass.
    2. Compute the inertia tensor and the eigenvalues and eigenvectors.
    3. Handle the symmetry detection based on eigenvalues.

        a. Linear molecules have one zero eigenvalue. Possible point groups
           are C*v or D*h.
        b. Asymmetric top molecules have all different eigenvalues. The
           maximum rotational symmetry in such molecules is 2.
        c. Symmetric top molecules have 1 unique eigenvalue, which gives a
           unique rotation axis.  All axial point groups are possible
           except the cubic groups (T & O) and I.
        d. Spherical top molecules have all three eigenvalues equal. They
           have the rare T, O or I point groups.

    .. attribute:: sch_symbol

        Schoenflies symbol of the detected point group.
    """

    inversion_op = SymmOp.inversion()

    def __init__(self, mol, tolerance=0.3, eigen_tolerance=0.01, matrix_tol=0.1):
        """
        The default settings are usually sufficient.

        Args:
            mol (Molecule): Molecule to determine point group for.
            tolerance (float): Distance tolerance to consider sites as
                symmetrically equivalent. Defaults to 0.3 Angstrom.
            eigen_tolerance (float): Tolerance to compare eigen values of
                the inertia tensor. Defaults to 0.01.
            matrix_tol (float): Tolerance used to generate the full set of
                symmetry operations of the point group.
        """
        self.mol = mol
        self.centered_mol = mol.get_centered_molecule()
        self.tol = tolerance
        self.eig_tol = eigen_tolerance
        self.mat_tol = matrix_tol
        self._analyze()
        # A group whose only non-trivial element is one mirror plane is
        # conventionally written Cs rather than C1v / C1h.
        if self.sch_symbol in ["C1v", "C1h"]:
            self.sch_symbol = "Cs"

    def _analyze(self):
        """
        Determine the point group and store it in ``self.sch_symbol``.

        A single-atom molecule is assigned "Kh". Otherwise the inertia tensor
        (normalized by the total moment) is diagonalized, and the molecule is
        dispatched to the linear, spherical-top, asymmetric-top or
        symmetric-top handler according to its eigenvalues.
        """
        # Initialize these for every molecule (including the single-atom
        # case) so get_pointgroup / get_symmetry_operations always work.
        self.rot_sym = []
        self.symmops = [SymmOp(np.eye(4))]
        if len(self.centered_mol) == 1:
            self.sch_symbol = "Kh"
        else:
            inertia_tensor = np.zeros((3, 3))
            total_inertia = 0
            for site in self.centered_mol:
                c = site.coords
                wt = site.species.weight
                for i in range(3):
                    inertia_tensor[i, i] += wt * (c[(i + 1) % 3] ** 2 + c[(i + 2) % 3] ** 2)
                for i, j in [(0, 1), (1, 2), (0, 2)]:
                    inertia_tensor[i, j] += -wt * c[i] * c[j]
                    inertia_tensor[j, i] += -wt * c[j] * c[i]
                total_inertia += wt * np.dot(c, c)

            # Normalize the inertia tensor so that it does not scale with size
            # of the system.  This mitigates the problem of choosing a proper
            # comparison tolerance for the eigenvalues.
            inertia_tensor /= total_inertia
            eigvals, eigvecs = np.linalg.eig(inertia_tensor)
            self.principal_axes = eigvecs.T
            self.eigvals = eigvals
            v1, v2, v3 = eigvals
            eig_zero = abs(v1 * v2 * v3) < self.eig_tol
            eig_all_same = abs(v1 - v2) < self.eig_tol and abs(v1 - v3) < self.eig_tol
            eig_all_diff = abs(v1 - v2) > self.eig_tol and abs(v1 - v3) > self.eig_tol and abs(v2 - v3) > self.eig_tol

            if eig_zero:
                logger.debug("Linear molecule detected")
                self._proc_linear()
            elif eig_all_same:
                logger.debug("Spherical top molecule detected")
                self._proc_sph_top()
            elif eig_all_diff:
                logger.debug("Asymmetric top molecule detected")
                self._proc_asym_top()
            else:
                logger.debug("Symmetric top molecule detected")
                self._proc_sym_top()

    def _proc_linear(self):
        """
        Handles linear molecules: D*h if inversion is a symmetry, else C*v.
        """
        if self.is_valid_op(PointGroupAnalyzer.inversion_op):
            self.sch_symbol = "D*h"
            self.symmops.append(PointGroupAnalyzer.inversion_op)
        else:
            self.sch_symbol = "C*v"

    def _proc_asym_top(self):
        """
        Handles asymmetric top molecules, which cannot contain rotational
        symmetry larger than 2.
        """
        self._check_R2_axes_asym()
        if len(self.rot_sym) == 0:
            logger.debug("No rotation symmetries detected.")
            self._proc_no_rot_sym()
        elif len(self.rot_sym) == 3:
            logger.debug("Dihedral group detected.")
            self._proc_dihedral()
        else:
            logger.debug("Cyclic group detected.")
            self._proc_cyclic()

    def _proc_sym_top(self):
        """
        Handles symmetric top molecules which have one unique eigenvalue whose
        corresponding principal axis is a unique rotational axis.  More complex
        handling required to look for R2 axes perpendicular to this unique
        axis.
        """
        # The unique axis is the one whose eigenvalue differs from the
        # (approximately equal) other two.
        if abs(self.eigvals[0] - self.eigvals[1]) < self.eig_tol:
            ind = 2
        elif abs(self.eigvals[1] - self.eigvals[2]) < self.eig_tol:
            ind = 0
        else:
            ind = 1
        logger.debug("Eigenvalues = %s.", self.eigvals)
        unique_axis = self.principal_axes[ind]
        self._check_rot_sym(unique_axis)
        logger.debug("Rotation symmetries = %s", self.rot_sym)
        if len(self.rot_sym) > 0:
            self._check_perpendicular_r2_axis(unique_axis)

        if len(self.rot_sym) >= 2:
            self._proc_dihedral()
        elif len(self.rot_sym) == 1:
            self._proc_cyclic()
        else:
            self._proc_no_rot_sym()

    def _proc_no_rot_sym(self):
        """
        Handles molecules with no rotational symmetry. Only possible point
        groups are C1, Cs and Ci.
        """
        self.sch_symbol = "C1"
        if self.is_valid_op(PointGroupAnalyzer.inversion_op):
            self.sch_symbol = "Ci"
            self.symmops.append(PointGroupAnalyzer.inversion_op)
        else:
            for v in self.principal_axes:
                mirror_type = self._find_mirror(v)
                if mirror_type:
                    self.sch_symbol = "Cs"
                    break

    def _proc_cyclic(self):
        """
        Handles cyclic group molecules (Cn, Cnh, Cnv or S2n).
        """
        main_axis, rot = max(self.rot_sym, key=lambda v: v[1])
        self.sch_symbol = "C{}".format(rot)
        mirror_type = self._find_mirror(main_axis)
        if mirror_type == "h":
            self.sch_symbol += "h"
        elif mirror_type == "v":
            self.sch_symbol += "v"
        elif mirror_type == "":
            if self.is_valid_op(SymmOp.rotoreflection(main_axis, angle=180 / rot)):
                self.sch_symbol = "S{}".format(2 * rot)

    def _proc_dihedral(self):
        """
        Handles dihedral group molecules, i.e those with intersecting R2 axes
        and a main axis (Dn, Dnh or Dnd).
        """
        main_axis, rot = max(self.rot_sym, key=lambda v: v[1])
        self.sch_symbol = "D{}".format(rot)
        mirror_type = self._find_mirror(main_axis)
        if mirror_type == "h":
            self.sch_symbol += "h"
        elif mirror_type:
            self.sch_symbol += "d"

    def _check_R2_axes_asym(self):
        """
        Test for 2-fold rotation along the principal axes. Used to handle
        asymmetric top molecules.
        """
        for v in self.principal_axes:
            op = SymmOp.from_axis_angle_and_translation(v, 180)
            if self.is_valid_op(op):
                self.symmops.append(op)
                self.rot_sym.append((v, 2))

    def _find_mirror(self, axis):
        """
        Looks for a mirror plane related to ``axis``.

        Horizontal (h) mirrors are perpendicular to the axis, while vertical
        (v) and diagonal (d) mirrors contain it. Any mirror found is appended
        to ``self.symmops``.

        Candidate v/d mirrors are planes through the origin whose normal is
        the difference vector of a pair of same-species atoms. When more than
        one rotation axis is known, the mirror is reported as "v" if one of
        the other rotation axes passes a (signed) dot-product test against
        the mirror normal, and as "d" otherwise.

        Args:
            axis (ndarray): Axis to test against.

        Returns:
            str: "h", "v", "d", or "" if no mirror was found.
        """
        mirror_type = ""

        # First test whether the axis itself is the normal to a mirror plane.
        if self.is_valid_op(SymmOp.reflection(axis)):
            self.symmops.append(SymmOp.reflection(axis))
            mirror_type = "h"
        else:
            # Iterate through all pairs of atoms to find mirror
            for s1, s2 in itertools.combinations(self.centered_mol, 2):
                if s1.species == s2.species:
                    normal = s1.coords - s2.coords
                    # NOTE(review): this is a signed comparison on unnormalized
                    # vectors, i.e. a loose pre-filter rather than a strict
                    # perpendicularity test. Candidates are still verified by
                    # is_valid_op. Point-group assignment may depend on this
                    # exact filter, so change it with care.
                    if np.dot(normal, axis) < self.tol:
                        op = SymmOp.reflection(normal)
                        if self.is_valid_op(op):
                            self.symmops.append(op)
                            if len(self.rot_sym) > 1:
                                mirror_type = "d"
                                for v, r in self.rot_sym:
                                    # Skip the main axis itself.
                                    if np.linalg.norm(v - axis) >= self.tol:
                                        if np.dot(v, normal) < self.tol:
                                            mirror_type = "v"
                                            break
                            else:
                                mirror_type = "v"
                            break

        return mirror_type

    def _get_smallest_set_not_on_axis(self, axis):
        """
        Returns the smallest list of atoms that share a species and distance
        from the origin (as grouped by ``cluster_sites``), after discarding
        atoms lying on the specified axis.

        The size of this set bounds the possible rotational symmetry about
        the axis. Atoms lying on a test axis are irrelevant when testing
        rotations about it.
        """

        def not_on_axis(site):
            v = np.cross(site.coords, axis)
            return np.linalg.norm(v) > self.tol

        valid_sets = []
        origin_site, dist_el_sites = cluster_sites(self.centered_mol, self.tol)
        for test_set in dist_el_sites.values():
            valid_set = list(filter(not_on_axis, test_set))
            if len(valid_set) > 0:
                valid_sets.append(valid_set)

        return min(valid_sets, key=len)

    def _check_rot_sym(self, axis):
        """
        Determines the rotational symmetry about supplied axis.  Used only for
        symmetric top molecules which have possible rotational symmetry
        operations > 2.

        Returns:
            int: Highest order n found for which the n-fold rotation is valid
            (the operation is recorded), or 1 if none was found.
        """
        min_set = self._get_smallest_set_not_on_axis(axis)
        max_sym = len(min_set)
        # Only divisors of the smallest orbit size can be rotation orders.
        for i in range(max_sym, 0, -1):
            if max_sym % i != 0:
                continue
            op = SymmOp.from_axis_angle_and_translation(axis, 360 / i)
            rotvalid = self.is_valid_op(op)
            if rotvalid:
                self.symmops.append(op)
                self.rot_sym.append((axis, i))
                return i
        return 1

    def _check_perpendicular_r2_axis(self, axis):
        """
        Checks for R2 axes perpendicular to unique axis.  For handling
        symmetric top molecules.

        Returns:
            True if a perpendicular 2-fold axis was found and recorded,
            otherwise None.
        """
        min_set = self._get_smallest_set_not_on_axis(axis)
        for s1, s2 in itertools.combinations(min_set, 2):
            test_axis = np.cross(s1.coords - s2.coords, axis)
            if np.linalg.norm(test_axis) > self.tol:
                op = SymmOp.from_axis_angle_and_translation(test_axis, 180)
                r2present = self.is_valid_op(op)
                if r2present:
                    self.symmops.append(op)
                    self.rot_sym.append((test_axis, 2))
                    return True
        return None

    def _proc_sph_top(self):
        """
        Handles Spherical Top Molecules, which belong to the T, O or I point
        groups. Falls back to the symmetric-top handling when no rotation
        axis of order >= 3 is found ("accidental" spherical top).
        """
        self._find_spherical_axes()
        if len(self.rot_sym) == 0:
            logger.debug("Accidental speherical top!")
            self._proc_sym_top()
            # _proc_sym_top has already assigned sch_symbol. Continuing would
            # take max() of a possibly empty rot_sym, or re-process the
            # molecule and overwrite the result.
            return
        main_axis, rot = max(self.rot_sym, key=lambda v: v[1])
        if rot < 3:
            logger.debug("Accidental speherical top!")
            self._proc_sym_top()
        elif rot == 3:
            mirror_type = self._find_mirror(main_axis)
            if mirror_type != "":
                if self.is_valid_op(PointGroupAnalyzer.inversion_op):
                    self.symmops.append(PointGroupAnalyzer.inversion_op)
                    self.sch_symbol = "Th"
                else:
                    self.sch_symbol = "Td"
            else:
                self.sch_symbol = "T"
        elif rot == 4:
            if self.is_valid_op(PointGroupAnalyzer.inversion_op):
                self.symmops.append(PointGroupAnalyzer.inversion_op)
                self.sch_symbol = "Oh"
            else:
                self.sch_symbol = "O"
        elif rot == 5:
            if self.is_valid_op(PointGroupAnalyzer.inversion_op):
                self.symmops.append(PointGroupAnalyzer.inversion_op)
                self.sch_symbol = "Ih"
            else:
                self.sch_symbol = "I"

    def _find_spherical_axes(self):
        """
        Looks for R5, R4, R3 and R2 axes in spherical top molecules.  Point
        group T molecules have only one unique 3-fold and one unique 2-fold
        axis. O molecules have one unique 4, 3 and 2-fold axes. I molecules
        have a unique 5-fold axis.

        Candidate 2-fold axes are sums of pairs of atom positions; candidate
        higher-order axes are normals of triangles of atoms, all taken from
        the smallest cluster returned by ``cluster_sites``.
        """
        rot_present = defaultdict(bool)
        origin_site, dist_el_sites = cluster_sites(self.centered_mol, self.tol)
        test_set = min(dist_el_sites.values(), key=len)
        coords = [s.coords for s in test_set]
        for c1, c2, c3 in itertools.combinations(coords, 3):
            for cc1, cc2 in itertools.combinations([c1, c2, c3], 2):
                if not rot_present[2]:
                    test_axis = cc1 + cc2
                    if np.linalg.norm(test_axis) > self.tol:
                        op = SymmOp.from_axis_angle_and_translation(test_axis, 180)
                        rot_present[2] = self.is_valid_op(op)
                        if rot_present[2]:
                            self.symmops.append(op)
                            self.rot_sym.append((test_axis, 2))

            test_axis = np.cross(c2 - c1, c3 - c1)
            if np.linalg.norm(test_axis) > self.tol:
                for r in (3, 4, 5):
                    if not rot_present[r]:
                        op = SymmOp.from_axis_angle_and_translation(test_axis, 360 / r)
                        rot_present[r] = self.is_valid_op(op)
                        if rot_present[r]:
                            self.symmops.append(op)
                            self.rot_sym.append((test_axis, r))
                            break
            if rot_present[2] and rot_present[3] and (rot_present[4] or rot_present[5]):
                break

    def get_pointgroup(self):
        """
        Returns a PointGroupOperations object for the molecule, built from
        the detected symbol and operations using ``matrix_tol``.
        """
        return PointGroupOperations(self.sch_symbol, self.symmops, self.mat_tol)

    def get_symmetry_operations(self):
        """
        Return symmetry operations as a list of SymmOp objects.
        Returns Cartesian coord symmops.

        Returns:
            ([SymmOp]): List of symmetry operations.
        """
        # NOTE(review): this closes the group with the distance tolerance
        # ``self.tol``, whereas get_pointgroup uses ``self.mat_tol`` (documented
        # as the tolerance for generating the full set of operations). Confirm
        # which is intended before changing it.
        return generate_full_symmops(self.symmops, self.tol)

    def is_valid_op(self, symmop):
        """
        Check if a particular symmetry operation is a valid symmetry operation
        for a molecule, i.e., the operation maps all atoms to another
        equivalent atom.

        Args:
            symmop (SymmOp): Symmetry operation to test.

        Returns:
            (bool): Whether SymmOp is valid for Molecule.
        """
        coords = self.centered_mol.cart_coords
        for site in self.centered_mol:
            coord = symmop.operate(site.coords)
            ind = find_in_coord_list(coords, coord, self.tol)
            # The image must match exactly one atom, and of the same species.
            if not (len(ind) == 1 and self.centered_mol[ind[0]].species == site.species):
                return False
        return True

    def _get_eq_sets(self):
        """
        Calculates the dictionary for mapping equivalent atoms onto each other.

        Args:
            None

        Returns:
            dict: The returned dictionary has two possible keys:

            ``eq_sets``:
            A dictionary of indices mapping to sets of indices,
            each key maps to indices of all equivalent atoms.
            The keys are guaranteed to be not equivalent.

            ``sym_ops``:
            Twofold nested dictionary.
            ``operations[i][j]`` gives the symmetry operation
            that maps atom ``i`` unto ``j``.
        """
        UNIT = np.eye(3)
        eq_sets, operations = defaultdict(set), defaultdict(dict)
        symm_ops = [op.rotation_matrix for op in generate_full_symmops(self.symmops, self.tol)]

        def get_clustered_indices():
            # Groups of site indices sharing species and distance from the
            # origin; the origin site (if any) forms its own group.
            indices = cluster_sites(self.centered_mol, self.tol, give_only_index=True)
            out = list(indices[1].values())
            if indices[0] is not None:
                out.append([indices[0]])
            return out

        for index in get_clustered_indices():
            sites = self.centered_mol.cart_coords[index]
            for i, reference in zip(index, sites):
                for op in symm_ops:
                    rotated = np.dot(op, sites.T).T
                    matched_indices = find_in_coord_list(rotated, reference, self.tol)
                    # Translate positions within this cluster back to
                    # molecule-wide site indices.
                    matched_indices = {index[k] for k in matched_indices}
                    eq_sets[i] |= matched_indices

                    if i not in operations:
                        operations[i] = {j: op.T if j != i else UNIT for j in matched_indices}
                    else:
                        for j in matched_indices:
                            if j not in operations[i]:
                                operations[i][j] = op.T if j != i else UNIT
                    for j in matched_indices:
                        if j not in operations:
                            operations[j] = {i: op if j != i else UNIT}
                        elif i not in operations[j]:
                            operations[j][i] = op if j != i else UNIT

        return {"eq_sets": eq_sets, "sym_ops": operations}

    @staticmethod
    def _combine_eq_sets(eq_sets, operations):
        """Combines the dicts of _get_equivalent_atom_dicts into one

        Args:
            eq_sets (dict)
            operations (dict)

        Returns:
            dict: The returned dictionary has two possible keys:

            ``eq_sets``:
            A dictionary of indices mapping to sets of indices,
            each key maps to indices of all equivalent atoms.
            The keys are guaranteed to be not equivalent.

            ``sym_ops``:
            Twofold nested dictionary.
            ``operations[i][j]`` gives the symmetry operation
            that maps atom ``i`` unto ``j``.
        """
        UNIT = np.eye(3)

        def all_equivalent_atoms_of_i(i, eq_sets, ops):
            """WORKS INPLACE on operations"""
            visited = {i}
            tmp_eq_sets = {j: (eq_sets[j] - visited) for j in eq_sets[i]}

            while tmp_eq_sets:
                new_tmp_eq_sets = {}
                for j in tmp_eq_sets:
                    if j in visited:
                        continue
                    visited.add(j)
                    for k in tmp_eq_sets[j]:
                        new_tmp_eq_sets[k] = eq_sets[k] - visited
                        if i not in ops[k]:
                            # Compose k -> j -> i to get k -> i.
                            ops[k][i] = np.dot(ops[j][i], ops[k][j]) if k != i else UNIT
                        ops[i][k] = ops[k][i].T
                tmp_eq_sets = new_tmp_eq_sets
            return visited, ops

        eq_sets = copy.deepcopy(eq_sets)
        ops = copy.deepcopy(operations)
        to_be_deleted = set()
        for i in eq_sets:
            if i in to_be_deleted:
                continue
            visited, ops = all_equivalent_atoms_of_i(i, eq_sets, ops)
            to_be_deleted |= visited - {i}

        for k in to_be_deleted:
            eq_sets.pop(k, None)
        return {"eq_sets": eq_sets, "sym_ops": ops}

    def get_equivalent_atoms(self):
        """Returns sets of equivalent atoms with symmetry operations

        Args:
            None

        Returns:
            dict: The returned dictionary has two possible keys:

            ``eq_sets``:
            A dictionary of indices mapping to sets of indices,
            each key maps to indices of all equivalent atoms.
            The keys are guaranteed to be not equivalent.

            ``sym_ops``:
            Twofold nested dictionary.
            ``operations[i][j]`` gives the symmetry operation
            that maps atom ``i`` unto ``j``.
        """
        eq = self._get_eq_sets()
        return self._combine_eq_sets(eq["eq_sets"], eq["sym_ops"])

    def symmetrize_molecule(self):
        """Returns a symmetrized molecule

        The equivalent atoms obtained via
        :meth:`~pymatgen.symmetry.analyzer.PointGroupAnalyzer.get_equivalent_atoms`
        are rotated, mirrored... unto one position.
        Then the average position is calculated.
        The average position is rotated, mirrored... back with the inverse
        of the previous symmetry operations, which gives the
        symmetrized molecule

        Args:
            None

        Returns:
            dict: The returned dictionary has three possible keys:

            ``sym_mol``:
            A symmetrized molecule instance.

            ``eq_sets``:
            A dictionary of indices mapping to sets of indices,
            each key maps to indices of all equivalent atoms.
            The keys are guaranteed to be not equivalent.

            ``sym_ops``:
            Twofold nested dictionary.
            ``operations[i][j]`` gives the symmetry operation
            that maps atom ``i`` unto ``j``.
        """
        eq = self.get_equivalent_atoms()
        eq_sets, ops = eq["eq_sets"], eq["sym_ops"]
        coords = self.centered_mol.cart_coords.copy()
        for i, eq_indices in eq_sets.items():
            # Map every equivalent atom onto atom i, then average.
            for j in eq_indices:
                coords[j] = np.dot(ops[j][i], coords[j])
            coords[i] = np.mean(coords[list(eq_indices)], axis=0)
            # Map the averaged position back onto each equivalent atom.
            for j in eq_indices:
                if j == i:
                    continue
                coords[j] = np.dot(ops[i][j], coords[i])
        molecule = Molecule(species=self.centered_mol.species_and_occu, coords=coords)
        return {"sym_mol": molecule, "eq_sets": eq_sets, "sym_ops": ops}
1399
1400
def iterative_symmetrize(mol, max_n=10, tolerance=0.3, epsilon=1e-2):
    """Returns a symmetrized molecule

    The equivalent atoms obtained via
    :meth:`~pymatgen.symmetry.analyzer.PointGroupAnalyzer.get_equivalent_atoms`
    are rotated, mirrored... unto one position.
    Then the average position is calculated.
    The average position is rotated, mirrored... back with the inverse
    of the previous symmetry operations, which gives the
    symmetrized molecule. This is repeated on the output until it stops
    changing or ``max_n`` passes have been made.

    Args:
        mol (Molecule): A pymatgen Molecule instance.
        max_n (int): Maximum number of iterations. At least one
            symmetrization pass is always performed.
        tolerance (float): Tolerance for detecting symmetry.
            Gets passed as Argument into
            :class:`~pymatgen.symmetry.analyzer.PointGroupAnalyzer`.
        epsilon (float): If the elementwise absolute difference of two
            subsequently symmetrized structures is smaller epsilon,
            the iteration stops before ``max_n`` is reached.


    Returns:
        dict: The returned dictionary has three possible keys:

        ``sym_mol``:
        A symmetrized molecule instance.

        ``eq_sets``:
        A dictionary of indices mapping to sets of indices,
        each key maps to indices of all equivalent atoms.
        The keys are guaranteed to be not equivalent.

        ``sym_ops``:
        Twofold nested dictionary.
        ``operations[i][j]`` gives the symmetry operation
        that maps atom ``i`` unto ``j``.
    """
    new = mol
    eq = None
    # Perform at most max_n passes, but always at least one so that ``eq``
    # is defined even for max_n < 1.
    for _ in range(max(max_n, 1)):
        previous = new
        analyzer = PointGroupAnalyzer(previous, tolerance=tolerance)
        eq = analyzer.symmetrize_molecule()
        new = eq["sym_mol"]
        if np.allclose(new.cart_coords, previous.cart_coords, atol=epsilon):
            break
    return eq
1450
1451
def cluster_sites(mol, tol, give_only_index=False):
    """
    Cluster sites based on distance from the origin and species type.

    Args:
        mol (Molecule): Molecule **with origin at center of mass**.
        tol (float): Tolerance to use. Distances from the origin are
            clustered with this cutoff, and sites whose cluster has an
            average distance below ``tol`` are treated as the origin site.
        give_only_index (bool): If True, return site indices instead of
            site objects.

    Returns:
        (origin_site, clustered_sites): origin_site is a site at the center
        of mass (None if there are no origin atoms). clustered_sites is a
        dict of {(avg_dist, species_and_occu): [list of sites]}, where
        avg_dist is the mean distance from the origin of the sites in that
        distance cluster. With give_only_index=True, site indices replace
        the sites in both return values.
    """
    import scipy.cluster as spcluster

    norms = [np.linalg.norm(site.coords) for site in mol]
    if len(norms) > 1:
        # fclusterdata expects multi-dimensional observations, so a dummy 0
        # is added as the second coordinate.
        labels = spcluster.hierarchy.fclusterdata([[d, 0] for d in norms], tol, criterion="distance")
    else:
        # scipy's linkage cannot cluster fewer than two observations; a
        # single site trivially forms its own cluster.
        labels = [1] * len(norms)

    clustered_dists = defaultdict(list)
    for label, d in zip(labels, norms):
        clustered_dists[label].append(d)
    # Average only the real distances. Including the dummy 0 column would
    # halve them, so sites up to 2 * tol away would count as the origin.
    avg_dist = {label: np.mean(val) for label, val in clustered_dists.items()}

    clustered_sites = defaultdict(list)
    origin_site = None
    for i, site in enumerate(mol):
        item = i if give_only_index else site
        dist = avg_dist[labels[i]]
        if dist < tol:
            origin_site = item
        else:
            clustered_sites[(dist, site.species)].append(item)
    return origin_site, clustered_sites
1489
1490
def generate_full_symmops(symmops, tol):
    """
    Close a set of symmetry operations under composition.

    The non-identity operations in ``symmops`` are used as generators.
    Every product of an operation already collected with a generator is
    added if it is not yet present, compared element-wise on the 4x4
    affine matrices within ``tol``. The identity is appended at the end if
    it was never produced. This assumes the supplied operations generate a
    finite set; otherwise the loop may not terminate.

    Args:
        symmops ([SymmOp]): Initial set of symmetry operations.
        tol (float): Element-wise tolerance on the affine matrices used to
            decide whether two operations are the same.

    Returns:
        Full set of symmetry operations as a list of SymmOp. If every
        operation in ``symmops`` is the identity, ``symmops`` itself is
        returned unchanged.
    """
    # Closure algorithm after Gregory Butler, "Fundamental Algorithms for
    # Permutation Groups", Lecture Notes in Computer Science (Book 559),
    # Springer, 1991, page 15.
    identity = np.eye(4)
    generators = []
    for candidate in symmops:
        matrix = candidate.affine_matrix
        if not np.allclose(matrix, identity):
            generators.append(matrix)
    # C1 symmetry: nothing to close over, and the algorithm below assumes at
    # least one generator.
    if not generators:
        return symmops

    def _is_known(group, matrix):
        """Whether ``matrix`` matches some element of ``group`` within ``tol``."""
        within = np.abs(np.asarray(group) - matrix) < tol
        return np.any(np.all(within, axis=(1, 2)))

    full = list(generators)
    pos = 0
    # ``full`` grows while it is being scanned, so walk it by index until
    # no new products are appended.
    while pos < len(full):
        current = full[pos]
        for gen in generators:
            product = np.dot(current, gen)
            if not _is_known(full, product):
                full.append(product)
        pos += 1

    if not _is_known(full, identity):
        full.append(identity)
    return [SymmOp(matrix) for matrix in full]
1527
1528
class SpacegroupOperations(list):
    """
    Represents a space group, which is a collection of symmetry operations.
    """

    def __init__(self, int_symbol, int_number, symmops):
        """
        Args:
            int_symbol (str): International symbol of the spacegroup.
            int_number (int): International number of the spacegroup.
            symmops ([SymmOp]): Symmetry operations associated with the
                spacegroup.
        """
        self.int_symbol = int_symbol
        self.int_number = int_number
        super().__init__(symmops)

    def are_symmetrically_equivalent(self, sites1, sites2, symm_prec=1e-3):
        """
        Given two sets of PeriodicSites, test if they are actually
        symmetrically equivalent under this space group.  Useful, for example,
        if you want to test if selecting atoms 1 and 2 out of a set of 4 atoms
        are symmetrically the same as selecting atoms 3 and 4, etc.

        One use is in PartialRemoveSpecie transformation to return only
        symmetrically distinct arrangements of atoms.

        The sets are equivalent if some operation of this group maps every
        site of ``sites2`` onto a site of ``sites1``. Collections of
        different sizes are never considered equivalent.

        Args:
            sites1 ([PeriodicSite]): 1st set of sites
            sites2 ([PeriodicSite]): 2nd set of sites
            symm_prec (float): Tolerance passed to
                ``PeriodicSite.is_periodic_image`` when testing whether a
                transformed site coincides with a site of ``sites1``.

        Returns:
            (bool): Whether the two sets of sites are symmetrically
            equivalent.
        """
        # Both collections are traversed repeatedly (sites1 once per
        # transformed site, sites2 once per operation), so take list copies.
        sites1 = list(sites1)
        sites2 = list(sites2)
        # Without this, a smaller sites2 (e.g. an empty one) would be reported
        # as equivalent to any sites1 containing its image.
        if len(sites1) != len(sites2):
            return False

        def in_sites(site):
            # Lattice check disabled: transformed sites carry sites2's lattice.
            return any(test_site.is_periodic_image(site, symm_prec, False) for test_site in sites1)

        for op in self:
            newsites2 = [PeriodicSite(site.species, op.operate(site.frac_coords), site.lattice) for site in sites2]
            if all(in_sites(site) for site in newsites2):
                return True
        return False

    def __str__(self):
        return "{} ({}) spacegroup".format(self.int_symbol, self.int_number)
1584
1585
class PointGroupOperations(list):
    """
    A point group, modelled as the list of its symmetry operations.

    .. attribute:: sch_symbol

        Schoenflies symbol of the point group.
    """

    def __init__(self, sch_symbol, operations, tol=0.1):
        """
        Args:
            sch_symbol (str): Schoenflies symbol of the point group.
            operations ([SymmOp]): Generating symmetry operations. Supplying
                only enough operations to generate the group is sufficient;
                the remainder are derived via generate_full_symmops.
            tol (float): Tolerance forwarded to generate_full_symmops when
                completing the set of operations.
        """
        full_operations = generate_full_symmops(operations, tol)
        super().__init__(full_operations)
        self.sch_symbol = sch_symbol

    def __str__(self):
        symbol = self.sch_symbol
        return symbol

    def __repr__(self):
        # The Schoenflies symbol doubles as the repr.
        return self.__str__()
1614