1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
6This module provides classes to perform fitting of structures.
7"""
8
9import abc
10import itertools
11
12import numpy as np
13from monty.json import MSONable
14
15from pymatgen.analysis.defects.core import Defect, Interstitial, Substitution, Vacancy
16from pymatgen.core import PeriodicSite
17from pymatgen.core.composition import Composition
18from pymatgen.core.lattice import Lattice
19from pymatgen.core.periodic_table import get_el_sp
20from pymatgen.core.structure import Structure
21from pymatgen.optimization.linear_assignment import LinearAssignment  # type: ignore
22from pymatgen.util.coord import lattice_points_in_supercell
23from pymatgen.util.coord_cython import (  # type: ignore
24    is_coord_subset_pbc,
25    pbc_shortest_vectors,
26)
27
# Module metadata (authorship, version and status information).
__author__ = "William Davidson Richards, Stephen Dacek, Shyue Ping Ong"
__copyright__ = "Copyright 2011, The Materials Project"
__version__ = "1.0"
__maintainer__ = "William Davidson Richards"
__email__ = "wrichard@mit.edu"
__status__ = "Production"
__date__ = "Dec 3, 2012"
35
36
class AbstractComparator(MSONable, metaclass=abc.ABCMeta):
    """
    Base class for site comparators. A comparator decides when the species
    occupying two sites should be treated as equivalent, and supplies a
    hash used to pre-group structures before detailed matching.
    """

    @abc.abstractmethod
    def are_equal(self, sp1, sp2):
        """
        Decide whether the species of two sites count as equal. Subclasses
        choose the criterion, e.g. exact species identity (Fe2+ matches
        Fe2+ but not Fe3+) or element identity (oxidation states ignored).

        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            Boolean indicating whether species are considered equal.
        """
        return None

    @abc.abstractmethod
    def get_hash(self, composition):
        """
        Produce a hashable key used to bucket structures before they are
        compared. The key must not change when a supercell is built, so a
        plain composition is unsuitable while a fractional composition may
        work. The reduced formula is avoided because it behaves poorly with
        fractional occupancies.

        A composition (rather than a structure) is passed because applying
        species substitutions to a composition is much cheaper, which
        matters for anonymous matching.

        Args:
            composition (Composition): composition of the structure

        Returns:
            A hashable object, e.g. a formula string or an integer.
        """
        return None

    @classmethod
    def from_dict(cls, d):
        """
        Rebuild a comparator from its dict representation.

        :param d: Dict representation, as produced by ``as_dict``.
        :return: Comparator.
        """
        class_name = d["@class"]
        for module_name in ("structure_matcher",):
            module = __import__(
                "pymatgen.analysis." + module_name,
                globals(),
                locals(),
                [class_name],
                0,
            )
            if not hasattr(module, class_name):
                continue
            return getattr(module, class_name)()
        raise ValueError("Invalid Comparator dict")

    def as_dict(self):
        """
        :return: MSONable dict
        """
        klass = type(self)
        return {
            "version": __version__,
            "@module": klass.__module__,
            "@class": klass.__name__,
        }
112
113
class SpeciesComparator(AbstractComparator):
    """
    Comparator requiring an exact species match (including oxidation
    state). This is the comparator StructureMatcher uses by default.
    """

    def are_equal(self, sp1, sp2):
        """
        Exact species equality: Fe2+ matches Fe2+ but not Fe3+.

        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            Boolean indicating whether species are equal.
        """
        identical = sp1 == sp2
        return identical

    def get_hash(self, composition):
        """
        Returns: Fractional composition
        """
        fractional = composition.fractional_composition
        return fractional
140
141
class SpinComparator(AbstractComparator):
    """
    A Comparator that matches magnetic structures to their inverse spins.
    This comparator is primarily used to filter magnetically ordered
    structures with opposite spins, which are equivalent.
    """

    # Tolerance used when comparing site occupancies.
    _amount_tolerance = 1e-8

    def are_equal(self, sp1, sp2):
        """
        True if the two sites carry the same species (same element symbol
        and oxidation state, e.g. Fe2+ matches Fe2+ but not Fe3+) with the
        same occupancies, but with every spin reversed, i.e. spin up maps
        to spin down and vice versa.

        The match is a one-to-one pairing: both sites must contain the same
        number of species, and every species in ``sp1`` must have a
        spin-reversed partner in ``sp2`` with the same amount.

        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            Boolean indicating whether species are equal.
        """
        # Without this, extra species present only in sp2 were ignored,
        # making the comparison asymmetric.
        if len(sp1) != len(sp2):
            return False

        for s1, amt1 in sp1.items():
            # Species without spin information are treated as spin 0; the
            # ``or 0`` also guards against a spin attribute set to None,
            # which would otherwise fail on negation below.
            spin1 = getattr(s1, "spin", 0) or 0
            oxi1 = getattr(s1, "oxi_state", 0)
            for s2, amt2 in sp2.items():
                spin2 = getattr(s2, "spin", 0) or 0
                oxi2 = getattr(s2, "oxi_state", 0)
                if (
                    s1.symbol == s2.symbol
                    and oxi1 == oxi2
                    and spin2 == -spin1
                    and abs(amt1 - amt2) < self._amount_tolerance
                ):
                    break
            else:
                return False
        # Each sp1 species maps to a distinct sp2 species (the partner is
        # fixed by symbol, oxidation state and spin) and the sizes are
        # equal, so the pairing is a bijection.
        return True

    def get_hash(self, composition):
        """
        Returns: Fractional composition
        """
        return composition.fractional_composition
181
182
class ElementComparator(AbstractComparator):
    """
    Comparator that matches on elements only; oxidation states are
    disregarded.
    """

    def are_equal(self, sp1, sp2):
        """
        Compare the element-to-amount maps of the two sites, so species
        differing only by oxidation state are treated as the same.

        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            Boolean indicating whether species are the same based on element
            and amounts.
        """
        first_amounts = Composition(sp1).get_el_amt_dict()
        second_amounts = Composition(sp2).get_el_amt_dict()
        return first_amounts == second_amounts

    def get_hash(self, composition):
        """
        Returns: Fractional element composition
        """
        element_only = composition.element_composition
        return element_only.fractional_composition
213
214
class FrameworkComparator(AbstractComparator):
    """
    Comparator that matches any occupied site to any other, ignoring what
    species sit on them.
    """

    def are_equal(self, sp1, sp2):
        """
        Every pair of sites is considered equal; species are not examined.

        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            True always
        """
        always_equal = True
        return always_equal

    def get_hash(self, composition):
        """
        No meaningful hash exists for this comparator, so every composition
        maps to the same constant (1).
        """
        constant_hash = 1
        return constant_hash
240
241
class OrderDisorderElementComparator(AbstractComparator):
    """
    A Comparator that matches sites when the set of species on one site is
    contained in the set of species on the other (so an ordered site can
    match a disordered site that includes its species).
    """

    def are_equal(self, sp1, sp2):
        """
        True if the species keys of one site form a subset of the species
        keys of the other (in either direction). Amounts are not compared.

        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            True if either site's species set is a subset of the other's,
            False otherwise.
        """
        first, second = set(sp1.elements), set(sp2.elements)
        return first <= second or second <= first

    def get_hash(self, composition):
        """
        Returns: Fractional composition
        """
        fractional = composition.fractional_composition
        return fractional
270
271
class OccupancyComparator(AbstractComparator):
    """
    Comparator that looks only at the occupancy values present on each
    site, irrespective of which species carry them.
    """

    def are_equal(self, sp1, sp2):
        """
        Args:
            sp1: First species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.
            sp2: Second species. A dict of {specie/element: amt} as per the
                definition in Site and PeriodicSite.

        Returns:
            True if the set of element amounts on the first site equals the
            set of element amounts on the second site.
        """
        first_occupancies = set(sp1.element_composition.values())
        second_occupancies = set(sp2.element_composition.values())
        return first_occupancies == second_occupancies

    def get_hash(self, composition):
        """
        :param composition: Composition.
        :return: 1. A sensible hash is hard to define for occupancy-only
            matching, so all compositions share this constant.
        """
        constant_hash = 1
        return constant_hash
297
298
299class StructureMatcher(MSONable):
300    """
301    Class to match structures by similarity.
302
303    Algorithm:
304
305    1. Given two structures: s1 and s2
306    2. Optional: Reduce to primitive cells.
307    3. If the number of sites do not match, return False
    4. Reduce s1 and s2 to Niggli cells
309    5. Optional: Scale s1 and s2 to same volume.
310    6. Optional: Remove oxidation states associated with sites
311    7. Find all possible lattice vectors for s2 within shell of ltol.
312    8. For s1, translate an atom in the smallest set to the origin
313    9. For s2: find all valid lattices from permutations of the list
314       of lattice vectors (invalid if: det(Lattice Matrix) < half
315       volume of original s2 lattice)
316    10. For each valid lattice:
317
        a. If the lattice angles are within tolerance of s1,
319           basis change s2 into new lattice.
320        b. For each atom in the smallest set of s2:
321
322            i. Translate to origin and compare fractional sites in
323            structure within a fractional tolerance.
324            ii. If true:
325
326                ia. Convert both lattices to cartesian and place
327                both structures on an average lattice
328                ib. Compute and return the average and max rms
329                displacement between the two structures normalized
330                by the average free length per atom
331
332                if fit function called:
333                    if normalized max rms displacement is less than
334                    stol. Return True
335
336                if get_rms_dist function called:
337                    if normalized average rms displacement is less
338                    than the stored rms displacement, store and
339                    continue. (This function will search all possible
340                    lattices for the smallest average rms displacement
341                    between the two structures)
342    """
343
344    def __init__(
345        self,
346        ltol=0.2,
347        stol=0.3,
348        angle_tol=5,
349        primitive_cell=True,
350        scale=True,
351        attempt_supercell=False,
352        allow_subset=False,
353        comparator=SpeciesComparator(),
354        supercell_size="num_sites",
355        ignored_species=None,
356    ):
357        """
358        Args:
359            ltol (float): Fractional length tolerance. Default is 0.2.
360            stol (float): Site tolerance. Defined as the fraction of the
361                average free length per atom := ( V / Nsites ) ** (1/3)
362                Default is 0.3.
363            angle_tol (float): Angle tolerance in degrees. Default is 5 degrees.
364            primitive_cell (bool): If true: input structures will be reduced to
365                primitive cells prior to matching. Default to True.
366            scale (bool): Input structures are scaled to equivalent volume if
367               true; For exact matching, set to False.
368            attempt_supercell (bool): If set to True and number of sites in
369                cells differ after a primitive cell reduction (divisible by an
370                integer) attempts to generate a supercell transformation of the
371                smaller cell which is equivalent to the larger structure.
372            allow_subset (bool): Allow one structure to match to the subset of
373                another structure. Eg. Matching of an ordered structure onto a
374                disordered one, or matching a delithiated to a lithiated
375                structure. This option cannot be combined with
376                attempt_supercell, or with structure grouping.
377            comparator (Comparator): A comparator object implementing an equals
378                method that declares declaring equivalency of sites. Default is
379                SpeciesComparator, which implies rigid species
380                mapping, i.e., Fe2+ only matches Fe2+ and not Fe3+.
381
382                Other comparators are provided, e.g., ElementComparator which
383                matches only the elements and not the species.
384
385                The reason why a comparator object is used instead of
386                supplying a comparison function is that it is not possible to
387                pickle a function, which makes it otherwise difficult to use
388                StructureMatcher with Python's multiprocessing.
389            supercell_size (str or list): Method to use for determining the
390                size of a supercell (if applicable). Possible values are
391                num_sites, num_atoms, volume, or an element or list of elements
392                present in both structures.
393            ignored_species (list): A list of ions to be ignored in matching.
394                Useful for matching structures that have similar frameworks
395                except for certain ions, e.g., Li-ion intercalation frameworks.
396                This is more useful than allow_subset because it allows better
397                control over what species are ignored in the matching.
398        """
399
400        self.ltol = ltol
401        self.stol = stol
402        self.angle_tol = angle_tol
403        self._comparator = comparator
404        self._primitive_cell = primitive_cell
405        self._scale = scale
406        self._supercell = attempt_supercell
407        self._supercell_size = supercell_size
408        self._subset = allow_subset
409        self._ignored_species = [] if ignored_species is None else ignored_species[:]
410
411    def _get_supercell_size(self, s1, s2):
412        """
413        Returns the supercell size, and whether the supercell should
414        be applied to s1. If fu == 1, s1_supercell is returned as
415        true, to avoid ambiguity.
416        """
417        if self._supercell_size == "num_sites":
418            fu = s2.num_sites / s1.num_sites
419        elif self._supercell_size == "num_atoms":
420            fu = s2.composition.num_atoms / s1.composition.num_atoms
421        elif self._supercell_size == "volume":
422            fu = s2.volume / s1.volume
423        elif not isinstance(self._supercell_size, str):
424            s1comp, s2comp = 0, 0
425            for el in self._supercell_size:
426                el = get_el_sp(el)
427                s1comp += s1.composition[el]
428                s2comp += s2.composition[el]
429            fu = s2comp / s1comp
430        else:
431            el = get_el_sp(self._supercell_size)
432            if (el in s2.composition) and (el in s1.composition):
433                fu = s2.composition[el] / s1.composition[el]
434            else:
435                raise ValueError("Invalid argument for supercell_size.")
436
437        if fu < 2 / 3:
438            return int(round(1 / fu)), False
439
440        return int(round(fu)), True
441
442    def _get_lattices(self, target_lattice, s, supercell_size=1):
443        """
444        Yields lattices for s with lengths and angles close to the
445        lattice of target_s. If supercell_size is specified, the
446        returned lattice will have that number of primitive cells
447        in it
448
449        Args:
450            s, target_s: Structure objects
451        """
452        lattices = s.lattice.find_all_mappings(
453            target_lattice,
454            ltol=self.ltol,
455            atol=self.angle_tol,
456            skip_rotation_matrix=True,
457        )
458        for l, _, scale_m in lattices:
459            if abs(abs(np.linalg.det(scale_m)) - supercell_size) < 0.5:
460                yield l, scale_m
461
462    def _get_supercells(self, struct1, struct2, fu, s1_supercell):
463        """
464        Computes all supercells of one structure close to the lattice of the
465        other
466        if s1_supercell == True, it makes the supercells of struct1, otherwise
467        it makes them of s2
468
469        yields: s1, s2, supercell_matrix, average_lattice, supercell_matrix
470        """
471
472        def av_lat(l1, l2):
473            params = (np.array(l1.parameters) + np.array(l2.parameters)) / 2
474            return Lattice.from_parameters(*params)
475
476        def sc_generator(s1, s2):
477            s2_fc = np.array(s2.frac_coords)
478            if fu == 1:
479                cc = np.array(s1.cart_coords)
480                for l, sc_m in self._get_lattices(s2.lattice, s1, fu):
481                    fc = l.get_fractional_coords(cc)
482                    fc -= np.floor(fc)
483                    yield fc, s2_fc, av_lat(l, s2.lattice), sc_m
484            else:
485                fc_init = np.array(s1.frac_coords)
486                for l, sc_m in self._get_lattices(s2.lattice, s1, fu):
487                    fc = np.dot(fc_init, np.linalg.inv(sc_m))
488                    lp = lattice_points_in_supercell(sc_m)
489                    fc = (fc[:, None, :] + lp[None, :, :]).reshape((-1, 3))
490                    fc -= np.floor(fc)
491                    yield fc, s2_fc, av_lat(l, s2.lattice), sc_m
492
493        if s1_supercell:
494            for x in sc_generator(struct1, struct2):
495                yield x
496        else:
497            for x in sc_generator(struct2, struct1):
498                # reorder generator output so s1 is still first
499                yield x[1], x[0], x[2], x[3]
500
501    @classmethod
502    def _cmp_fstruct(cls, s1, s2, frac_tol, mask):
503        """
504        Returns true if a matching exists between s2 and s2
505        under frac_tol. s2 should be a subset of s1
506        """
507        if len(s2) > len(s1):
508            raise ValueError("s1 must be larger than s2")
509        if mask.shape != (len(s2), len(s1)):
510            raise ValueError("mask has incorrect shape")
511
512        return is_coord_subset_pbc(s2, s1, frac_tol, mask)
513
514    @classmethod
515    def _cart_dists(cls, s1, s2, avg_lattice, mask, normalization, lll_frac_tol=None):
516        """
517        Finds a matching in cartesian space. Finds an additional
518        fractional translation vector to minimize RMS distance
519
520        Args:
521            s1, s2: numpy arrays of fractional coordinates. len(s1) >= len(s2)
522            avg_lattice: Lattice on which to calculate distances
523            mask: numpy array of booleans. mask[i, j] = True indicates
524                that s2[i] cannot be matched to s1[j]
525            normalization (float): inverse normalization length
526
527        Returns:
528            Distances from s2 to s1, normalized by (V/Natom) ^ 1/3
529            Fractional translation vector to apply to s2.
530            Mapping from s1 to s2, i.e. with numpy slicing, s1[mapping] => s2
531        """
532        if len(s2) > len(s1):
533            raise ValueError("s1 must be larger than s2")
534        if mask.shape != (len(s2), len(s1)):
535            raise ValueError("mask has incorrect shape")
536
537        # vectors are from s2 to s1
538        vecs, d_2 = pbc_shortest_vectors(avg_lattice, s2, s1, mask, return_d2=True, lll_frac_tol=lll_frac_tol)
539        lin = LinearAssignment(d_2)
540        s = lin.solution  # pylint: disable=E1101
541        short_vecs = vecs[np.arange(len(s)), s]
542        translation = np.average(short_vecs, axis=0)
543        f_translation = avg_lattice.get_fractional_coords(translation)
544        new_d2 = np.sum((short_vecs - translation) ** 2, axis=-1)
545
546        return new_d2 ** 0.5 * normalization, f_translation, s
547
548    def _get_mask(self, struct1, struct2, fu, s1_supercell):
549        """
550        Returns mask for matching struct2 to struct1. If struct1 has sites
551        a b c, and fu = 2, assumes supercells of struct2 will be ordered
552        aabbcc (rather than abcabc)
553
554        Returns:
555        mask, struct1 translation indices, struct2 translation index
556        """
557        mask = np.zeros((len(struct2), len(struct1), fu), dtype=np.bool)
558
559        inner = []
560        for sp2, i in itertools.groupby(enumerate(struct2.species_and_occu), key=lambda x: x[1]):
561            i = list(i)
562            inner.append((sp2, slice(i[0][0], i[-1][0] + 1)))
563
564        for sp1, j in itertools.groupby(enumerate(struct1.species_and_occu), key=lambda x: x[1]):
565            j = list(j)
566            j = slice(j[0][0], j[-1][0] + 1)
567            for sp2, i in inner:
568                mask[i, j, :] = not self._comparator.are_equal(sp1, sp2)
569
570        if s1_supercell:
571            mask = mask.reshape((len(struct2), -1))
572        else:
573            # supercell is of struct2, roll fu axis back to preserve
574            # correct ordering
575            mask = np.rollaxis(mask, 2, 1)
576            mask = mask.reshape((-1, len(struct1)))
577
578        # find the best translation indices
579        i = np.argmax(np.sum(mask, axis=-1))
580        inds = np.where(np.invert(mask[i]))[0]
581        if s1_supercell:
582            # remove the symmetrically equivalent s1 indices
583            inds = inds[::fu]
584        return np.array(mask, dtype=np.int_), inds, i
585
586    def fit(self, struct1, struct2, symmetric=False):
587        """
588        Fit two structures.
589
590        Args:
591            struct1 (Structure): 1st structure
592            struct2 (Structure): 2nd structure
593            symmetric (Bool): Defaults to False
594                If True, check the equality both ways.
595                This only impacts a small percentage of structures
596
597        Returns:
598            True or False.
599        """
600        struct1, struct2 = self._process_species([struct1, struct2])
601
602        if not self._subset and self._comparator.get_hash(struct1.composition) != self._comparator.get_hash(
603            struct2.composition
604        ):
605            return None
606
607        if not symmetric:
608            struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
609            match = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True)
610            if match is None:
611                return False
612
613            return match[0] <= self.stol
614
615        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
616        match1 = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True)
617        struct1, struct2 = struct2, struct1
618        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
619        match2 = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True)
620
621        if match1 is None or match2 is None:
622            return False
623
624        return max(match1[0], match2[0]) <= self.stol
625
626    def get_rms_dist(self, struct1, struct2):
627        """
628        Calculate RMS displacement between two structures
629
630        Args:
631            struct1 (Structure): 1st structure
632            struct2 (Structure): 2nd structure
633
634        Returns:
635            rms displacement normalized by (Vol / nsites) ** (1/3)
636            and maximum distance between paired sites. If no matching
637            lattice is found None is returned.
638        """
639        struct1, struct2 = self._process_species([struct1, struct2])
640        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
641        match = self._match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False)
642
643        if match is None:
644            return None
645
646        return match[0], max(match[1])
647
648    def _process_species(self, structures):
649        copied_structures = []
650        for s in structures:
651            # We need the copies to be actual Structure to work properly, not
652            # subclasses. So do type(s) == Structure.
653            ss = Structure.from_sites(s)
654            if self._ignored_species:
655                ss.remove_species(self._ignored_species)
656            copied_structures.append(ss)
657        return copied_structures
658
659    def _preprocess(self, struct1, struct2, niggli=True):
660        """
661        Rescales, finds the reduced structures (primitive and niggli),
662        and finds fu, the supercell size to make struct1 comparable to
663        s2
664        """
665        struct1 = struct1.copy()
666        struct2 = struct2.copy()
667
668        if niggli:
669            struct1 = struct1.get_reduced_structure(reduction_algo="niggli")
670            struct2 = struct2.get_reduced_structure(reduction_algo="niggli")
671
672        # primitive cell transformation
673        if self._primitive_cell:
674            struct1 = struct1.get_primitive_structure()
675            struct2 = struct2.get_primitive_structure()
676
677        if self._supercell:
678            fu, s1_supercell = self._get_supercell_size(struct1, struct2)
679        else:
680            fu, s1_supercell = 1, True
681        mult = fu if s1_supercell else 1 / fu
682
683        # rescale lattice to same volume
684        if self._scale:
685            ratio = (struct2.volume / (struct1.volume * mult)) ** (1 / 6)
686            nl1 = Lattice(struct1.lattice.matrix * ratio)
687            struct1.lattice = nl1
688            nl2 = Lattice(struct2.lattice.matrix / ratio)
689            struct2.lattice = nl2
690
691        return struct1, struct2, fu, s1_supercell
692
693    def _match(
694        self,
695        struct1,
696        struct2,
697        fu,
698        s1_supercell=True,
699        use_rms=False,
700        break_on_match=False,
701    ):
702        """
703        Matches one struct onto the other
704        """
705        ratio = fu if s1_supercell else 1 / fu
706        if len(struct1) * ratio >= len(struct2):
707            return self._strict_match(
708                struct1,
709                struct2,
710                fu,
711                s1_supercell=s1_supercell,
712                break_on_match=break_on_match,
713                use_rms=use_rms,
714            )
715        return self._strict_match(
716            struct2,
717            struct1,
718            fu,
719            s1_supercell=(not s1_supercell),
720            break_on_match=break_on_match,
721            use_rms=use_rms,
722        )
723
724    def _strict_match(
725        self,
726        struct1,
727        struct2,
728        fu,
729        s1_supercell=True,
730        use_rms=False,
731        break_on_match=False,
732    ):
733        """
734        Matches struct2 onto struct1 (which should contain all sites in
735        struct2).
736
737        Args:
738            struct1, struct2 (Structure): structures to be matched
739            fu (int): size of supercell to create
740            s1_supercell (bool): whether to create the supercell of
741                struct1 (vs struct2)
742            use_rms (bool): whether to minimize the rms of the matching
743            break_on_match (bool): whether to stop search at first
744                valid match
745        """
746        if fu < 1:
747            raise ValueError("fu cannot be less than 1")
748
749        mask, s1_t_inds, s2_t_ind = self._get_mask(struct1, struct2, fu, s1_supercell)
750
751        if mask.shape[0] > mask.shape[1]:
752            raise ValueError("after supercell creation, struct1 must " "have more sites than struct2")
753
754        # check that a valid mapping exists
755        if (not self._subset) and mask.shape[1] != mask.shape[0]:
756            return None
757
758        if LinearAssignment(mask).min_cost > 0:  # pylint: disable=E1101
759            return None
760
761        best_match = None
762        # loop over all lattices
763        for s1fc, s2fc, avg_l, sc_m in self._get_supercells(struct1, struct2, fu, s1_supercell):
764            # compute fractional tolerance
765            normalization = (len(s1fc) / avg_l.volume) ** (1 / 3)
766            inv_abc = np.array(avg_l.reciprocal_lattice.abc)
767            frac_tol = inv_abc * self.stol / (np.pi * normalization)
768            # loop over all translations
769            for s1i in s1_t_inds:
770                t = s1fc[s1i] - s2fc[s2_t_ind]
771                t_s2fc = s2fc + t
772                if self._cmp_fstruct(s1fc, t_s2fc, frac_tol, mask):
773                    inv_lll_abc = np.array(avg_l.get_lll_reduced_lattice().reciprocal_lattice.abc)
774                    lll_frac_tol = inv_lll_abc * self.stol / (np.pi * normalization)
775                    dist, t_adj, mapping = self._cart_dists(s1fc, t_s2fc, avg_l, mask, normalization, lll_frac_tol)
776                    if use_rms:
777                        val = np.linalg.norm(dist) / len(dist) ** 0.5
778                    else:
779                        val = max(dist)
780                    # pylint: disable=E1136
781                    if best_match is None or val < best_match[0]:
782                        total_t = t + t_adj
783                        total_t -= np.round(total_t)
784                        best_match = val, dist, sc_m, total_t, mapping
785                        if (break_on_match or val < 1e-5) and val < self.stol:
786                            return best_match
787
788        if best_match and best_match[0] < self.stol:
789            return best_match
790
791        return None
792
793    def group_structures(self, s_list, anonymous=False):
794        """
795        Given a list of structures, use fit to group
796        them by structural equality.
797
798        Args:
799            s_list ([Structure]): List of structures to be grouped
800            anonymous (bool): Whether to use anonymous mode.
801
802        Returns:
803            A list of lists of matched structures
804            Assumption: if s1 == s2 but s1 != s3, than s2 and s3 will be put
805            in different groups without comparison.
806        """
807        if self._subset:
808            raise ValueError("allow_subset cannot be used with" " group_structures")
809
810        original_s_list = list(s_list)
811        s_list = self._process_species(s_list)
812
813        # Use structure hash to pre-group structures
814        if anonymous:
815
816            def c_hash(c):
817                return c.anonymized_formula
818
819        else:
820            c_hash = self._comparator.get_hash
821
822        def s_hash(s):
823            return c_hash(s[1].composition)
824
825        sorted_s_list = sorted(enumerate(s_list), key=s_hash)
826        all_groups = []
827
828        # For each pre-grouped list of structures, perform actual matching.
829        for k, g in itertools.groupby(sorted_s_list, key=s_hash):
830            unmatched = list(g)
831            while len(unmatched) > 0:
832                i, refs = unmatched.pop(0)
833                matches = [i]
834                if anonymous:
835                    inds = filter(
836                        lambda i: self.fit_anonymous(refs, unmatched[i][1]),
837                        list(range(len(unmatched))),
838                    )
839                else:
840                    inds = filter(
841                        lambda i: self.fit(refs, unmatched[i][1]),
842                        list(range(len(unmatched))),
843                    )
844                inds = list(inds)
845                matches.extend([unmatched[i][0] for i in inds])
846                unmatched = [unmatched[i] for i in range(len(unmatched)) if i not in inds]
847                all_groups.append([original_s_list[i] for i in matches])
848
849        return all_groups
850
851    def as_dict(self):
852        """
853        :return: MSONable dict
854        """
855        return {
856            "version": __version__,
857            "@module": self.__class__.__module__,
858            "@class": self.__class__.__name__,
859            "comparator": self._comparator.as_dict(),
860            "stol": self.stol,
861            "ltol": self.ltol,
862            "angle_tol": self.angle_tol,
863            "primitive_cell": self._primitive_cell,
864            "scale": self._scale,
865            "attempt_supercell": self._supercell,
866            "allow_subset": self._subset,
867            "supercell_size": self._supercell_size,
868            "ignored_species": self._ignored_species,
869        }
870
871    @classmethod
872    def from_dict(cls, d):
873        """
874        :param d: Dict representation
875        :return: StructureMatcher
876        """
877        return StructureMatcher(
878            ltol=d["ltol"],
879            stol=d["stol"],
880            angle_tol=d["angle_tol"],
881            primitive_cell=d["primitive_cell"],
882            scale=d["scale"],
883            attempt_supercell=d["attempt_supercell"],
884            allow_subset=d["allow_subset"],
885            comparator=AbstractComparator.from_dict(d["comparator"]),
886            supercell_size=d["supercell_size"],
887            ignored_species=d["ignored_species"],
888        )
889
890    def _anonymous_match(
891        self,
892        struct1,
893        struct2,
894        fu,
895        s1_supercell=True,
896        use_rms=False,
897        break_on_match=False,
898        single_match=False,
899    ):
900        """
901        Tries all permutations of matching struct1 to struct2.
902        Args:
903            struct1, struct2 (Structure): Preprocessed input structures
904        Returns:
905            List of (mapping, match)
906        """
907        if not isinstance(self._comparator, SpeciesComparator):
908            raise ValueError("Anonymous fitting currently requires SpeciesComparator")
909
910        # check that species lists are comparable
911        sp1 = struct1.composition.elements
912        sp2 = struct2.composition.elements
913        if len(sp1) != len(sp2):
914            return None
915
916        ratio = fu if s1_supercell else 1 / fu
917        swapped = len(struct1) * ratio < len(struct2)
918
919        s1_comp = struct1.composition
920        s2_comp = struct2.composition
921        matches = []
922        for perm in itertools.permutations(sp2):
923            sp_mapping = dict(zip(sp1, perm))
924
925            # do quick check that compositions are compatible
926            mapped_comp = Composition({sp_mapping[k]: v for k, v in s1_comp.items()})
927            if (not self._subset) and (self._comparator.get_hash(mapped_comp) != self._comparator.get_hash(s2_comp)):
928                continue
929
930            mapped_struct = struct1.copy()
931            mapped_struct.replace_species(sp_mapping)
932            if swapped:
933                m = self._strict_match(
934                    struct2,
935                    mapped_struct,
936                    fu,
937                    (not s1_supercell),
938                    use_rms,
939                    break_on_match,
940                )
941            else:
942                m = self._strict_match(mapped_struct, struct2, fu, s1_supercell, use_rms, break_on_match)
943            if m:
944                matches.append((sp_mapping, m))
945                if single_match:
946                    break
947        return matches
948
949    def get_rms_anonymous(self, struct1, struct2):
950        """
951        Performs an anonymous fitting, which allows distinct species in one
952        structure to map to another. E.g., to compare if the Li2O and Na2O
953        structures are similar.
954
955        Args:
956            struct1 (Structure): 1st structure
957            struct2 (Structure): 2nd structure
958
959        Returns:
960            (min_rms, min_mapping)
961            min_rms is the minimum rms distance, and min_mapping is the
962            corresponding minimal species mapping that would map
963            struct1 to struct2. (None, None) is returned if the minimax_rms
964            exceeds the threshold.
965        """
966        struct1, struct2 = self._process_species([struct1, struct2])
967        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
968
969        matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False)
970        if matches:
971            best = sorted(matches, key=lambda x: x[1][0])[0]
972            return best[1][0], best[0]
973
974        return None, None
975
976    def get_best_electronegativity_anonymous_mapping(self, struct1, struct2):
977        """
978        Performs an anonymous fitting, which allows distinct species in one
979        structure to map to another. E.g., to compare if the Li2O and Na2O
980        structures are similar. If multiple substitutions are within tolerance
981        this will return the one which minimizes the difference in
982        electronegativity between the matches species.
983
984        Args:
985            struct1 (Structure): 1st structure
986            struct2 (Structure): 2nd structure
987
988        Returns:
989            min_mapping (Dict): Mapping of struct1 species to struct2 species
990        """
991        struct1, struct2 = self._process_species([struct1, struct2])
992        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
993
994        matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=True)
995
996        if matches:
997            min_X_diff = np.inf
998            for m in matches:
999                X_diff = 0
1000                for k, v in m[0].items():
1001                    X_diff += struct1.composition[k] * (k.X - v.X) ** 2
1002                if X_diff < min_X_diff:
1003                    min_X_diff = X_diff
1004                    best = m[0]
1005            return best
1006
1007        return None
1008
1009    def get_all_anonymous_mappings(self, struct1, struct2, niggli=True, include_dist=False):
1010        """
1011        Performs an anonymous fitting, which allows distinct species in one
1012        structure to map to another. Returns a dictionary of species
1013        substitutions that are within tolerance
1014
1015        Args:
1016            struct1 (Structure): 1st structure
1017            struct2 (Structure): 2nd structure
1018            niggli (bool): Find niggli cell in preprocessing
1019            include_dist (bool): Return the maximin distance with each mapping
1020
1021        Returns:
1022            list of species mappings that map struct1 to struct2.
1023        """
1024        struct1, struct2 = self._process_species([struct1, struct2])
1025        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli)
1026
1027        matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, break_on_match=not include_dist)
1028        if matches:
1029            if include_dist:
1030                return [(m[0], m[1][0]) for m in matches]
1031
1032            return [m[0] for m in matches]
1033
1034        return None
1035
1036    def fit_anonymous(self, struct1, struct2, niggli=True):
1037        """
1038        Performs an anonymous fitting, which allows distinct species in one
1039        structure to map to another. E.g., to compare if the Li2O and Na2O
1040        structures are similar.
1041
1042        Args:
1043            struct1 (Structure): 1st structure
1044            struct2 (Structure): 2nd structure
1045
1046        Returns:
1047            True/False: Whether a species mapping can map struct1 to stuct2
1048        """
1049        struct1, struct2 = self._process_species([struct1, struct2])
1050        struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli)
1051
1052        matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, break_on_match=True, single_match=True)
1053
1054        return bool(matches)
1055
1056    def get_supercell_matrix(self, supercell, struct):
1057        """
1058        Returns the matrix for transforming struct to supercell. This
1059        can be used for very distorted 'supercells' where the primitive cell
1060        is impossible to find
1061        """
1062        if self._primitive_cell:
1063            raise ValueError("get_supercell_matrix cannot be used with the " "primitive cell option")
1064        struct, supercell, fu, s1_supercell = self._preprocess(struct, supercell, False)
1065
1066        if not s1_supercell:
1067            raise ValueError(
1068                "The non-supercell must be put onto the basis" " of the supercell, not the other way around"
1069            )
1070
1071        match = self._match(struct, supercell, fu, s1_supercell, use_rms=True, break_on_match=False)
1072
1073        if match is None:
1074            return None
1075
1076        return match[2]
1077
1078    def get_transformation(self, struct1, struct2):
1079        """
1080        Returns the supercell transformation, fractional translation vector,
1081        and a mapping to transform struct2 to be similar to struct1.
1082
1083        Args:
1084            struct1 (Structure): Reference structure
1085            struct2 (Structure): Structure to transform.
1086
1087        Returns:
1088            supercell (numpy.ndarray(3, 3)): supercell matrix
1089            vector (numpy.ndarray(3)): fractional translation vector
1090            mapping (list(int or None)):
1091                The first len(struct1) items of the mapping vector are the
1092                indices of struct1's corresponding sites in struct2 (or None
1093                if there is no corresponding site), and the other items are
1094                the remaining site indices of struct2.
1095        """
1096        if self._primitive_cell:
1097            raise ValueError("get_transformation cannot be used with the " "primitive cell option")
1098
1099        struct1, struct2 = self._process_species((struct1, struct2))
1100
1101        s1, s2, fu, s1_supercell = self._preprocess(struct1, struct2, False)
1102        ratio = fu if s1_supercell else 1 / fu
1103        if s1_supercell and fu > 1:
1104            raise ValueError("Struct1 must be the supercell, " "not the other way around")
1105
1106        if len(s1) * ratio >= len(s2):
1107            # s1 is superset
1108            match = self._strict_match(s1, s2, fu=fu, s1_supercell=False, use_rms=True, break_on_match=False)
1109            if match is None:
1110                return None
1111            # invert the mapping, since it needs to be from s1 to s2
1112            mapping = [list(match[4]).index(i) if i in match[4] else None for i in range(len(s1))]
1113            return match[2], match[3], mapping
1114        # s2 is superset
1115        match = self._strict_match(s2, s1, fu=fu, s1_supercell=True, use_rms=True, break_on_match=False)
1116        if match is None:
1117            return None
1118        # add sites not included in the mapping
1119        not_included = list(range(len(s2) * fu))
1120        for i in match[4]:
1121            not_included.remove(i)
1122        mapping = list(match[4]) + not_included
1123        return match[2], -match[3], mapping
1124
1125    def get_s2_like_s1(self, struct1, struct2, include_ignored_species=True):
1126        """
1127        Performs transformations on struct2 to put it in a basis similar to
1128        struct1 (without changing any of the inter-site distances)
1129
1130        Args:
1131            struct1 (Structure): Reference structure
1132            struct2 (Structure): Structure to transform.
1133            include_ignored_species (bool): Defaults to True,
1134                the ignored_species is also transformed to the struct1
1135                lattice orientation, though obviously there is no direct
1136                matching to existing sites.
1137
1138        Returns:
1139            A structure object similar to struct1, obtained by making a
1140            supercell, sorting, and translating struct2.
1141        """
1142        s1, s2 = self._process_species([struct1, struct2])
1143        trans = self.get_transformation(s1, s2)
1144        if trans is None:
1145            return None
1146        sc, t, mapping = trans
1147        sites = list(s2)
1148        # Append the ignored sites at the end.
1149        sites.extend([site for site in struct2 if site not in s2])
1150        temp = Structure.from_sites(sites)
1151
1152        temp.make_supercell(sc)
1153        temp.translate_sites(list(range(len(temp))), t)
1154        # translate sites to correct unit cell
1155        for i, j in enumerate(mapping[: len(s1)]):
1156            if j is not None:
1157                vec = np.round(struct1[i].frac_coords - temp[j].frac_coords)
1158                temp.translate_sites(j, vec, to_unit_cell=False)
1159
1160        sites = [temp.sites[i] for i in mapping if i is not None]
1161
1162        if include_ignored_species:
1163            start = int(round(len(temp) / len(struct2) * len(s2)))
1164            sites.extend(temp.sites[start:])
1165
1166        return Structure.from_sites(sites)
1167
1168    def get_mapping(self, superset, subset):
1169        """
1170        Calculate the mapping from superset to subset.
1171
1172        Args:
1173            superset (Structure): Structure containing at least the sites in
1174                subset (within the structure matching tolerance)
1175            subset (Structure): Structure containing some of the sites in
1176                superset (within the structure matching tolerance)
1177
1178        Returns:
1179            numpy array such that superset.sites[mapping] is within matching
1180            tolerance of subset.sites or None if no such mapping is possible
1181        """
1182        if self._supercell:
1183            raise ValueError("cannot compute mapping to supercell")
1184        if self._primitive_cell:
1185            raise ValueError("cannot compute mapping with primitive cell " "option")
1186        if len(subset) > len(superset):
1187            raise ValueError("subset is larger than superset")
1188
1189        superset, subset, _, _ = self._preprocess(superset, subset, True)
1190        match = self._strict_match(superset, subset, 1, break_on_match=False)
1191
1192        if match is None or match[0] > self.stol:
1193            return None
1194
1195        return match[4]
1196
1197
class PointDefectComparator(MSONable):
    """
    A class that matches pymatgen Point Defect objects even if their
    cartesian co-ordinates are different (compares sublattices for the defect)

    NOTE: for defect complexes (more than a single defect),
    this comparator will break.
    """

    def __init__(self, check_charge=False, check_primitive_cell=False, check_lattice_scale=False):
        """
        Args:
            check_charge (bool): If True, defects must also have identical
                charges to be considered equal.
                Default is False (differently charged defects can be the same).
            check_primitive_cell (bool): Passed to StructureMatcher as
                primitive_cell, which allows different supercells of the
                bulk_structure to be compared rather than requiring the same
                supercell size.
                Default is False (the bulk_structure of each defect must be the
                same size).
            check_lattice_scale (bool): Passed to StructureMatcher as scale,
                which scales the volumes of the two structures to each other so
                that differing lattice constants are tolerated.
                Default is False (lattice constants must agree in both
                structures).
        """
        self.check_charge = check_charge
        self.check_primitive_cell = check_primitive_cell
        self.check_lattice_scale = check_lattice_scale

    def are_equal(self, d1, d2):
        """
        Compare two point defects by type, defect-site species, optional
        charge, and the sublattice they occupy in their bulk structures.

        Args:
            d1: First defect. A pymatgen Defect object.
            d2: Second defect. A pymatgen Defect object.

        Returns:
            True if defects are identical in type and sublattice.

        Raises:
            ValueError: If either argument is not a defect object.
        """
        possible_defect_types = (Defect, Vacancy, Substitution, Interstitial)

        if not isinstance(d1, possible_defect_types) or not isinstance(d2, possible_defect_types):
            raise ValueError("Cannot use PointDefectComparator to compare non-defect objects...")

        # Exact type equality keeps the comparison symmetric. The previous
        # isinstance(d1, d2.__class__) accepted a subclass on one side only,
        # so are_equal(a, b) could differ from are_equal(b, a).
        if type(d1) is not type(d2):
            return False
        if d1.site.specie != d2.site.specie:
            return False
        if self.check_charge and (d1.charge != d2.charge):
            return False

        sm = StructureMatcher(
            ltol=0.01,
            primitive_cell=self.check_primitive_cell,
            scale=self.check_lattice_scale,
        )

        # Defects on different bulk structures can never be equivalent.
        if not sm.fit(d1.bulk_structure, d2.bulk_structure):
            return False

        # Work on copies so the caller's defects are not modified below.
        d1 = d1.copy()
        d2 = d2.copy()
        if self.check_primitive_cell or self.check_lattice_scale:
            # if allowing for base structure volume or supercell modifications,
            # then need to preprocess defect objects to allow for matching
            d1_mod_bulk_structure, d2_mod_bulk_structure, _, _ = sm._preprocess(d1.bulk_structure, d2.bulk_structure)
            # NOTE(review): the defect's Cartesian coordinates are reused on the
            # preprocessed lattice. If _preprocess rescales the lattice
            # (check_lattice_scale=True), confirm that the defect still lands on
            # the intended sublattice.
            d1_defect_site = PeriodicSite(
                d1.site.specie,
                d1.site.coords,
                d1_mod_bulk_structure.lattice,
                to_unit_cell=True,
                coords_are_cartesian=True,
            )
            d2_defect_site = PeriodicSite(
                d2.site.specie,
                d2.site.coords,
                d2_mod_bulk_structure.lattice,
                to_unit_cell=True,
                coords_are_cartesian=True,
            )

            d1._structure = d1_mod_bulk_structure
            d2._structure = d2_mod_bulk_structure
            d1._defect_site = d1_defect_site
            d2._defect_site = d2_defect_site

        return sm.fit(d1.generate_defect_structure(), d2.generate_defect_structure())
1283