1"""Determine symmetry equivalence of two structures.
2Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012)."""
3from collections import Counter
4from itertools import combinations, product, filterfalse
5
6import numpy as np
7from scipy.spatial import cKDTree as KDTree
8
9from ase import Atom, Atoms
10from ase.build.tools import niggli_reduce
11
12
def normalize(cell):
    """Scale the first three rows of *cell* to unit length, in place.

    Each row is divided by its own Euclidean norm. Nothing is returned;
    the array passed in is modified.
    """
    for row_index in (0, 1, 2):
        row_length = np.linalg.norm(cell[row_index])
        cell[row_index] /= row_length
16
17
class SpgLibNotFoundError(Exception):
    """Signal that spglib was requested but could not be imported."""

    def __init__(self, msg):
        # Forward the message unchanged to the base Exception
        super().__init__(msg)
23
24
class SymmetryEquivalenceCheck:
    """Compare two structures to determine if they are symmetry equivalent.

    Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012).

    Parameters:

    angle_tol: float
        angle tolerance for the lattice vectors in degrees

    ltol: float
        relative tolerance for the length of the lattice vectors (per atom)

    stol: float
        position tolerance for the site comparison in units of
        (V/N)^(1/3) (average length between atoms)

    vol_tol: float
        volume tolerance in angstrom cubed to compare the volumes of
        the two structures

    scale_volume: bool
        if True the volumes of the two structures are scaled to be equal

    to_primitive: bool
        if True the structures are reduced to their primitive cells
        note that this feature requires spglib to be installed

    Examples:

    >>> from ase.build import bulk
    >>> from ase.utils.structure_comparator import SymmetryEquivalenceCheck
    >>> comp = SymmetryEquivalenceCheck()

    Compare a cell with a rotated version

    >>> a = bulk('Al', orthorhombic=True)
    >>> b = a.copy()
    >>> b.rotate(60, 'x', rotate_cell=True)
    >>> comp.compare(a, b)
    True

    Transform to the primitive cell and then compare

    >>> pa = bulk('Al')
    >>> comp.compare(a, pa)
    False
    >>> comp = SymmetryEquivalenceCheck(to_primitive=True)
    >>> comp.compare(a, pa)
    True

    Compare one structure with a list of other structures

    >>> import numpy as np
    >>> from ase import Atoms
    >>> s1 = Atoms('H3', positions=[[0.5, 0.5, 0],
    ...                             [0.5, 1.5, 0],
    ...                             [1.5, 1.5, 0]],
    ...            cell=[2, 2, 2], pbc=True)
    >>> comp = SymmetryEquivalenceCheck(stol=0.068)
    >>> s2_list = []
    >>> for d in np.linspace(0.1, 1.0, 5):
    ...     s2 = s1.copy()
    ...     s2.positions[0] += [d, 0, 0]
    ...     s2_list.append(s2)
    >>> comp.compare(s1, s2_list[:-1])
    False
    >>> comp.compare(s1, s2_list)
    True

    """

    def __init__(self, angle_tol=1.0, ltol=0.05, stol=0.05, vol_tol=0.1,
                 scale_volume=False, to_primitive=False):
        self.angle_tol = angle_tol * np.pi / 180.0  # convert to radians
        self.scale_volume = scale_volume
        self.stol = stol
        self.ltol = ltol
        self.vol_tol = vol_tol
        # Absolute position tolerance; recomputed in compare() from stol
        self.position_tolerance = 0.0
        self.to_primitive = to_primitive

        # Variables to be used in the compare function
        self.s1 = None
        self.s2 = None
        self.expanded_s1 = None
        self.expanded_s2 = None
        self.least_freq_element = None

    def _niggli_reduce(self, atoms):
        """Reduce to niggli cells.

        Reduce the atoms to niggli cells, then rotates the niggli cells to
        the so called "standard" orientation with one lattice vector along the
        x-axis and a second vector in the xy plane.
        """
        niggli_reduce(atoms)
        self._standarize_cell(atoms)

    def _standarize_cell(self, atoms):
        """Rotate *atoms* in place to the "standard" orientation.

        The first lattice vector is rotated onto the positive x-axis (first
        around the y-axis, then around the z-axis). The structure is then
        rotated around the x-axis so that the second lattice vector lies in
        the xy plane with a non-negative y component. The positions are
        rotated together with the cell and wrapped back into it.

        Parameters:

        atoms: Atoms object
            Structure to rotate; its cell and positions are overwritten.

        Returns the same (modified) atoms object.
        """
        # Work with the lattice vectors stored as columns
        cell = atoms.get_cell().T
        total_rot_mat = np.eye(3)

        # Each step is a rotation in the plane of the Cartesian axes
        # (keep, zero) chosen such that the ``zero`` component of lattice
        # vector number ``vec`` vanishes:
        #   1. around y: remove the z component of the first vector
        #   2. around z: remove the y component of the first vector
        #   3. around x: remove the z component of the second vector
        steps = ((0, 2, 0),
                 (0, 1, 0),
                 (1, 2, 1))
        for keep, zero, vec in steps:
            v = cell[:, vec]
            # arctan2 resolves the quadrant and is well defined when both
            # components are zero (it then returns 0, i.e. no rotation),
            # which avoids the 0/0 of an arcsin based formula
            angle = np.arctan2(v[zero], v[keep])
            ca = np.cos(angle)
            sa = np.sin(angle)
            rotmat = np.eye(3)
            rotmat[keep, keep] = ca
            rotmat[keep, zero] = sa
            rotmat[zero, keep] = -sa
            rotmat[zero, zero] = ca
            total_rot_mat = rotmat.dot(total_rot_mat)
            cell = rotmat.dot(cell)

        atoms.set_cell(cell.T)
        atoms.set_positions(total_rot_mat.dot(atoms.get_positions().T).T)
        atoms.wrap(pbc=[1, 1, 1])
        return atoms

    def _get_element_count(self, struct):
        """Return a Counter mapping atomic number to number of atoms."""
        return Counter(struct.numbers)

    def _get_angles(self, cell):
        """Get the internal angles of the unit cell.

        Returns the angles (in radians) between lattice vectors
        (0, 1), (0, 2) and (1, 2), in that order. The cell passed in is
        not modified.
        """
        cell = cell.copy()

        normalize(cell)

        dot = cell.dot(cell.T)

        # Extract only the relevant dot products
        dot = [dot[0, 1], dot[0, 2], dot[1, 2]]

        # Return angles
        return np.arccos(dot)

    def _has_same_elements(self):
        """Check if s1 and s2 contain the same number of each element."""
        elem1 = self._get_element_count(self.s1)
        return elem1 == self._get_element_count(self.s2)

    def _has_same_angles(self):
        """Check that the Niggli unit vectors have the same internal angles.

        The sorted angles of the two cells are compared with an absolute
        tolerance of ``self.angle_tol`` (radians).
        """
        ang1 = np.sort(self._get_angles(self.s1.get_cell()))
        ang2 = np.sort(self._get_angles(self.s2.get_cell()))

        return np.allclose(ang1, ang2, rtol=0, atol=self.angle_tol)

    def _has_same_volume(self):
        """Check that the volumes of s1 and s2 differ by less than vol_tol."""
        vol1 = self.s1.get_volume()
        vol2 = self.s2.get_volume()
        return np.abs(vol1 - vol2) < self.vol_tol

    def _scale_volumes(self):
        """Scale the cell of s2 to have the same volume as s1.

        The atoms of s2 are scaled together with the cell.
        """
        cell2 = self.s2.get_cell()
        # Get the volumes. The determinant is negative for a left-handed
        # cell, so take absolute values: a negative ratio would make the
        # cube root below evaluate to NaN.
        v2 = np.abs(np.linalg.det(cell2))
        v1 = np.abs(np.linalg.det(self.s1.get_cell()))

        # Scale the cells
        coordinate_scaling = (v1 / v2)**(1.0 / 3.0)
        cell2 *= coordinate_scaling
        self.s2.set_cell(cell2, scale_atoms=True)

    def compare(self, s1, s2):
        """Compare the two structures.

        Return *True* if the two structures are equivalent, *False* otherwise.

        Parameters:

        s1: Atoms object.
            Transformation matrices are calculated based on this structure.
            The object passed in is left untouched; all manipulations are
            performed on copies.

        s2: Atoms or list
            s1 can be compared to one structure or many structures supplied in
            a list. If s2 is a list it returns True if any structure in s2
            matches s1, False otherwise.
        """
        if self.to_primitive:
            s1 = self._reduce_to_primitive(s1)
        else:
            # The reference is translated and wrapped below; do that on a
            # private copy so that the caller's structure is not altered
            s1 = s1.copy()
        self._set_least_frequent_element(s1)
        self._least_frequent_element_to_origin(s1)
        self.s1 = s1.copy()
        vol = self.s1.get_volume()
        self.expanded_s1 = None

        if isinstance(s2, Atoms):
            # Just make it a list of length 1
            s2 = [s2]

        # These are computed for the first candidate that passes the cheap
        # checks and are then reused for the remaining candidates
        matrices = None
        translations = None
        transposed_matrices = None
        for struct in s2:
            self.s2 = struct.copy()
            self.expanded_s2 = None

            if self.to_primitive:
                self.s2 = self._reduce_to_primitive(self.s2)

            # Compare number of elements in structures
            if len(self.s1) != len(self.s2):
                continue

            # Compare chemical formulae
            if not self._has_same_elements():
                continue

            # Compare angles. self.s1 is reset to a fresh copy at the end
            # of every full iteration, so it is reduced on each pass.
            self._niggli_reduce(self.s1)
            self._niggli_reduce(self.s2)
            if not self._has_same_angles():
                continue

            # Compare volumes
            if self.scale_volume:
                self._scale_volumes()
            if not self._has_same_volume():
                continue

            if matrices is None:
                matrices = self._get_rotation_reflection_matrices()
                if matrices is None:
                    continue

            if translations is None:
                translations = self._get_least_frequent_positions(self.s1)

            # After the candidate translation based on s1 has been computed
            # we need potentially to swap s1 and s2 for robust comparison
            self._least_frequent_element_to_origin(self.s2)
            switch = self._switch_reference_struct()
            if switch:
                # Remember the matrices and translations used before
                old_matrices = matrices
                old_translations = translations

                # If a s1 and s2 has been switched we need to use the
                # transposed version of the matrices to map atoms the
                # other way
                if transposed_matrices is None:
                    transposed_matrices = np.transpose(matrices,
                                                       axes=[0, 2, 1])
                matrices = transposed_matrices
                translations = self._get_least_frequent_positions(self.s1)

            # Calculate tolerance on positions
            self.position_tolerance = \
                self.stol * (vol / len(self.s2))**(1.0 / 3.0)

            if self._positions_match(matrices, translations):
                return True

            # Set the reference structure back to its original
            self.s1 = s1.copy()
            if switch:
                self.expanded_s1 = self.expanded_s2
                matrices = old_matrices
                translations = old_translations
        return False

    def _set_least_frequent_element(self, atoms):
        """Save the atomic number of the least frequent element."""
        elem1 = self._get_element_count(atoms)
        self.least_freq_element = elem1.most_common()[-1][0]

    def _get_least_frequent_positions(self, atoms):
        """Get the (wrapped) positions of the least frequent element."""
        pos = atoms.get_positions(wrap=True)
        return pos[atoms.numbers == self.least_freq_element]

    def _get_only_least_frequent_of(self, struct):
        """Get the atoms object with all other elements than the least frequent
        one removed. Wrap the positions to get everything in the cell."""
        pos = struct.get_positions(wrap=True)

        indices = struct.numbers == self.least_freq_element
        least_freq_struct = struct[indices]
        least_freq_struct.set_positions(pos[indices])

        return least_freq_struct

    def _switch_reference_struct(self):
        """There is an intrinsic asymmetry in the system because
        one of the atoms are being expanded, while the other is not.
        This can cause the algorithm to return different result
        depending on which structure is passed first.
        We adopt the convention of using the atoms object
        having the fewest atoms in its expanded cell as the
        reference object.
        We return True if a switch of structures has been performed."""

        # First expand the cells
        if self.expanded_s1 is None:
            self.expanded_s1 = self._expand(self.s1)
        if self.expanded_s2 is None:
            self.expanded_s2 = self._expand(self.s2)

        exp1 = self.expanded_s1
        exp2 = self.expanded_s2
        if len(exp1) < len(exp2):
            # The expanded s2 is what the positions are matched against.
            # s1 has the smaller expanded cell, so swap the two structures
            # (and their expanded versions) to make it take that role.
            s1_temp = self.s1.copy()
            self.s1 = self.s2
            self.s2 = s1_temp
            exp1_temp = self.expanded_s1.copy()
            self.expanded_s1 = self.expanded_s2
            self.expanded_s2 = exp1_temp
            return True
        return False

    def _positions_match(self, rotation_reflection_matrices, translations):
        """Check if the position and elements match.

        Every combination of candidate translation and candidate
        rotation/reflection matrix is applied to s1 and the result is
        compared against the expanded s2.

        Note that this function changes self.s1 to the rotation and
        translation that was tried last. Hence, it is crucial that this
        function calls the element comparison, not the other way around.
        """
        pos1_ref = self.s1.get_positions(wrap=True)

        # Get the expanded reference object
        exp2 = self.expanded_s2
        # Build a KD tree to enable fast look-up of nearest neighbours
        tree = KDTree(exp2.get_positions())
        for i in range(translations.shape[0]):
            # Translate
            pos1_trans = pos1_ref - translations[i]
            for matrix in rotation_reflection_matrices:
                # Rotate
                pos1 = matrix.dot(pos1_trans.T).T

                # Update the atoms positions
                self.s1.set_positions(pos1)
                self.s1.wrap(pbc=[1, 1, 1])
                if self._elements_match(self.s1, exp2, tree):
                    return True
        return False

    def _expand(self, ref_atoms, tol=0.0001):
        """If an atom is closer to a boundary than tol it is repeated at the
        opposite boundaries.

        This ensures that atoms having crossed the cell boundaries due to
        numerical noise are properly detected.

        The distance between a position and cell boundary is calculated as:
        dot(position, (b_vec x c_vec) / (|b_vec| |c_vec|) ), where x is the
        cross product.

        Returns a copy of ref_atoms with the extra atoms appended.
        """
        syms = ref_atoms.get_chemical_symbols()
        cell = ref_atoms.get_cell()
        positions = ref_atoms.get_positions(wrap=True)
        expanded_atoms = ref_atoms.copy()

        # Calculate normal vectors to the unit cell faces
        normal_vectors = np.array([np.cross(cell[1, :], cell[2, :]),
                                   np.cross(cell[0, :], cell[2, :]),
                                   np.cross(cell[0, :], cell[1, :])])
        normalize(normal_vectors)

        # Get the distance to the unit cell faces from each atomic position
        pos2faces = np.abs(positions.dot(normal_vectors.T))

        # And the opposite faces
        pos2oppofaces = np.abs(np.dot(positions - np.sum(cell, axis=0),
                                      normal_vectors.T))

        for i, i2face in enumerate(pos2faces):
            # Append indices for positions close to the other faces
            # and convert to boolean array signifying if the position at
            # index i is close to the faces bordering origo (0, 1, 2) or
            # the opposite faces (3, 4, 5)
            i_close2face = np.append(i2face, pos2oppofaces[i]) < tol
            # For each position i.e. row it holds that
            # 1 x True -> close to face -> 1 extra atom at opposite face
            # 2 x True -> close to edge -> 3 extra atoms at opposite edges
            # 3 x True -> close to corner -> 7 extra atoms opposite corners
            # E.g. to add atoms at all corners we need to use the cell
            # vectors: (a, b, c, a + b, a + c, b + c, a + b + c), we use
            # itertools.combinations to get them all
            for j in range(sum(i_close2face)):
                for c in combinations(np.nonzero(i_close2face)[0], j + 1):
                    # Get the displacement vectors by adding the corresponding
                    # cell vectors, if the atom is close to an opposite face
                    # i.e. k > 2 subtract the cell vector
                    disp_vec = np.zeros(3)
                    for k in c:
                        disp_vec += cell[k % 3] * (int(k < 3) * 2 - 1)
                    pos = positions[i] + disp_vec
                    expanded_atoms.append(Atom(syms[i], position=pos))
        return expanded_atoms

    def _equal_elements_in_array(self, arr):
        """Return True if *arr* contains at least one repeated value."""
        s = np.sort(arr)
        return np.any(s[1:] == s[:-1])

    def _elements_match(self, s1, s2, kdtree):
        """Check if all the elements in s1 match the corresponding position in s2

        Parameters:

        s1: Atoms object
            Structure whose atoms are looked up.

        s2: Atoms object
            Structure the KD tree was built from.

        kdtree: KDTree
            Tree over the positions of s2.

        Returns True only if every atom of s1 has a nearest neighbour in
        s2 of the same element within ``self.position_tolerance`` and no
        atom of s2 is the nearest neighbour of more than one atom of s1.
        """
        # Only the identity ordering of x, y, z is tried, so the positions
        # are used as they are
        pos = s1.get_positions()
        dists, closest_in_s2 = kdtree.query(pos)

        # Check if the elements are the same
        if not np.all(s2.numbers[closest_in_s2] == s1.numbers):
            return False

        # Check if any distance is too large
        if np.any(dists > self.position_tolerance):
            return False

        # Check for duplicates in what atom is closest
        if self._equal_elements_in_array(closest_in_s2):
            return False

        return True

    def _least_frequent_element_to_origin(self, atoms):
        """Put one of the least frequent elements at the origin.

        The atoms are shifted in place (and wrapped) such that the first
        atom of the least frequent element ends up at 1e-6 times the cell
        diagonal, i.e. just inside the cell next to the origin.
        """
        least_freq_pos = self._get_least_frequent_positions(atoms)
        cell_diag = np.sum(atoms.get_cell(), axis=0)
        d = least_freq_pos[0] - 1e-6 * cell_diag
        atoms.positions -= d
        atoms.wrap(pbc=[1, 1, 1])

    def _get_rotation_reflection_matrices(self):
        """Compute candidates for the transformation matrix.

        Trial lattice vectors are searched among the positions of the
        least frequent element of s1 repeated in a 3x3x3 supercell. Triples
        whose lengths and mutual angles match the cell vectors of s2 within
        the tolerances are kept.

        Returns an array of shape (n, 3, 3) with one candidate matrix per
        accepted triple, or None if no candidate could be found.
        """
        atoms1_ref = self._get_only_least_frequent_of(self.s1)
        cell = self.s1.get_cell().T
        cell_diag = np.sum(cell, axis=1)
        angle_tol = self.angle_tol

        # Additional vector that is added to make sure that
        # there always is an atom at the origin
        delta_vec = 1E-6 * cell_diag

        # Store three reference vectors and their lengths
        ref_vec = self.s2.get_cell()
        ref_vec_lengths = np.linalg.norm(ref_vec, axis=1)

        # Compute ref vec angles
        # ref_angles are arranged as [angle12, angle13, angle23]
        ref_angles = np.array(self._get_angles(ref_vec))
        large_angles = ref_angles > np.pi / 2.0
        ref_angles[large_angles] = np.pi - ref_angles[large_angles]

        # Translate by one cell diagonal so that a central cell is
        # surrounded by cells in all directions
        sc_atom_search = atoms1_ref * (3, 3, 3)
        new_sc_pos = sc_atom_search.get_positions()
        new_sc_pos -= new_sc_pos[0] + cell_diag - delta_vec

        lengths = np.linalg.norm(new_sc_pos, axis=1)

        candidate_indices = []
        rtol = self.ltol / len(self.s1)
        for k in range(3):
            correct_lengths_mask = np.isclose(lengths,
                                              ref_vec_lengths[k],
                                              rtol=rtol, atol=0)
            # The first vector is not interesting
            correct_lengths_mask[0] = False

            # If no trial vectors can be found (for any direction)
            # then the candidates are different and we return None
            if not np.any(correct_lengths_mask):
                return None

            candidate_indices.append(np.nonzero(correct_lengths_mask)[0])

        # Now we calculate all relevant angles in one step. The relevant angles
        # are the ones made by the current candidates. We will have to keep
        # track of the indices in the angles matrix and the indices in the
        # position and length arrays.

        # Get all candidate indices (aci), only unique values
        aci = np.sort(list(set().union(*candidate_indices)))

        # Make a dictionary from original positions and lengths index to
        # index in angle matrix
        i2ang = dict(zip(aci, range(len(aci))))

        # Calculate the dot product divided by the lengths:
        # cos(angle) = dot(vec1, vec2) / |vec1| |vec2|
        cosa = np.inner(new_sc_pos[aci],
                        new_sc_pos[aci]) / np.outer(lengths[aci],
                                                    lengths[aci])
        # Make sure the inverse cosine will work
        cosa[cosa > 1] = 1
        cosa[cosa < -1] = -1
        angles = np.arccos(cosa)
        # Do trick for enantiomorphic structures
        angles[angles > np.pi / 2] = np.pi - angles[angles > np.pi / 2]

        # Check which angles match the reference angles
        # Test for all combinations on candidates. filterfalse makes sure
        # that there are no duplicate candidates. product is the same as
        # nested for loops.
        refined_candidate_list = []
        for p in filterfalse(self._equal_elements_in_array,
                             product(*candidate_indices)):
            a = np.array([angles[i2ang[p[0]], i2ang[p[1]]],
                          angles[i2ang[p[0]], i2ang[p[2]]],
                          angles[i2ang[p[1]], i2ang[p[2]]]])

            if np.allclose(a, ref_angles, atol=angle_tol, rtol=0):
                refined_candidate_list.append(new_sc_pos[np.array(p)].T)

        # Get the rotation/reflection matrix [R] by:
        # [R] = [V][T]^-1, where [V] is the reference vectors and
        # [T] is the trial vectors
        if len(refined_candidate_list) == 0:
            return None

        # Every entry is a 3x3 matrix with the trial vectors as columns.
        # np.linalg.inv inverts a stack of matrices of any length, so a
        # single candidate needs no special treatment.
        inverted_trial = np.linalg.inv(np.array(refined_candidate_list))

        # NOTE(review): for the stacked (n, 3, 3) input the double
        # transpose appears to yield (ref_vec.T @ inverted_trial[i]).T for
        # each candidate i rather than the plain product; the expression
        # is kept unchanged - confirm the intent.
        candidate_trans_mat = np.dot(ref_vec.T, inverted_trial.T).T
        return candidate_trans_mat

    def _reduce_to_primitive(self, structure):
        """Reduce the structure to its primitive cell using spglib.

        Returns a new, fully periodic Atoms object built from the cell,
        scaled positions and atomic numbers returned by spglib.

        Raises SpgLibNotFoundError if spglib cannot be imported.
        """
        try:
            import spglib
        except ImportError as err:
            raise SpgLibNotFoundError(
                "SpgLib is required if to_primitive=True") from err
        cell = (structure.get_cell()).tolist()
        pos = structure.get_scaled_positions().tolist()
        numbers = structure.get_atomic_numbers()

        cell, scaled_pos, numbers = spglib.standardize_cell(
            (cell, pos, numbers), to_primitive=True)

        atoms = Atoms(
            scaled_positions=scaled_pos,
            numbers=numbers,
            cell=cell,
            pbc=True)
        return atoms
618