1# coding: utf-8 2# Copyright (c) Pymatgen Development Team. 3# Distributed under the terms of the MIT License. 4 5""" 6This module provides classes to perform fitting of structures. 7""" 8 9import abc 10import itertools 11 12import numpy as np 13from monty.json import MSONable 14 15from pymatgen.analysis.defects.core import Defect, Interstitial, Substitution, Vacancy 16from pymatgen.core import PeriodicSite 17from pymatgen.core.composition import Composition 18from pymatgen.core.lattice import Lattice 19from pymatgen.core.periodic_table import get_el_sp 20from pymatgen.core.structure import Structure 21from pymatgen.optimization.linear_assignment import LinearAssignment # type: ignore 22from pymatgen.util.coord import lattice_points_in_supercell 23from pymatgen.util.coord_cython import ( # type: ignore 24 is_coord_subset_pbc, 25 pbc_shortest_vectors, 26) 27 28__author__ = "William Davidson Richards, Stephen Dacek, Shyue Ping Ong" 29__copyright__ = "Copyright 2011, The Materials Project" 30__version__ = "1.0" 31__maintainer__ = "William Davidson Richards" 32__email__ = "wrichard@mit.edu" 33__status__ = "Production" 34__date__ = "Dec 3, 2012" 35 36 37class AbstractComparator(MSONable, metaclass=abc.ABCMeta): 38 """ 39 Abstract Comparator class. A Comparator defines how sites are compared in 40 a structure. 41 """ 42 43 @abc.abstractmethod 44 def are_equal(self, sp1, sp2): 45 """ 46 Defines how the species of two sites are considered equal. For 47 example, one can consider sites to have the same species only when 48 the species are exactly the same, i.e., Fe2+ matches Fe2+ but not 49 Fe3+. Or one can define that only the element matters, 50 and all oxidation state information are ignored. 51 52 Args: 53 sp1: First species. A dict of {specie/element: amt} as per the 54 definition in Site and PeriodicSite. 55 sp2: Second species. A dict of {specie/element: amt} as per the 56 definition in Site and PeriodicSite. 57 58 Returns: 59 Boolean indicating whether species are considered equal. 60 """ 61 return 62 63 @abc.abstractmethod 64 def get_hash(self, composition): 65 """ 66 Defines a hash to group structures. This allows structures to be 67 grouped efficiently for comparison. The hash must be invariant under 68 supercell creation. (e.g. composition is not a good hash, but 69 fractional_composition might be). Reduced formula is not a good formula, 70 due to weird behavior with fractional occupancy. 71 72 Composition is used here instead of structure because for anonymous 73 matches it is much quicker to apply a substitution to a composition 74 object than a structure object. 75 76 Args: 77 composition (Composition): composition of the structure 78 79 Returns: 80 A hashable object. Examples can be string formulas, integers etc. 81 """ 82 return 83 84 @classmethod 85 def from_dict(cls, d): 86 """ 87 :param d: Dict representation 88 :return: Comparator. 89 """ 90 for trans_modules in ["structure_matcher"]: 91 mod = __import__( 92 "pymatgen.analysis." + trans_modules, 93 globals(), 94 locals(), 95 [d["@class"]], 96 0, 97 ) 98 if hasattr(mod, d["@class"]): 99 trans = getattr(mod, d["@class"]) 100 return trans() 101 raise ValueError("Invalid Comparator dict") 102 103 def as_dict(self): 104 """ 105 :return: MSONable dict 106 """ 107 return { 108 "version": __version__, 109 "@module": self.__class__.__module__, 110 "@class": self.__class__.__name__, 111 } 112 113 114class SpeciesComparator(AbstractComparator): 115 """ 116 A Comparator that matches species exactly. The default used in 117 StructureMatcher. 118 """ 119 120 def are_equal(self, sp1, sp2): 121 """ 122 True if species are exactly the same, i.e., Fe2+ == Fe2+ but not Fe3+. 123 124 Args: 125 sp1: First species. A dict of {specie/element: amt} as per the 126 definition in Site and PeriodicSite. 127 sp2: Second species. A dict of {specie/element: amt} as per the 128 definition in Site and PeriodicSite. 129 130 Returns: 131 Boolean indicating whether species are equal. 132 """ 133 return sp1 == sp2 134 135 def get_hash(self, composition): 136 """ 137 Returns: Fractional composition 138 """ 139 return composition.fractional_composition 140 141 142class SpinComparator(AbstractComparator): 143 """ 144 A Comparator that matches magnetic structures to their inverse spins. 145 This comparator is primarily used to filter magnetically ordered 146 structures with opposite spins, which are equivalent. 147 """ 148 149 def are_equal(self, sp1, sp2): 150 """ 151 True if species are exactly the same, i.e., Fe2+ == Fe2+ but not 152 Fe3+. and the spins are reversed. i.e., spin up maps to spin down, 153 and vice versa. 154 155 Args: 156 sp1: First species. A dict of {specie/element: amt} as per the 157 definition in Site and PeriodicSite. 158 sp2: Second species. A dict of {specie/element: amt} as per the 159 definition in Site and PeriodicSite. 160 161 Returns: 162 Boolean indicating whether species are equal. 163 """ 164 for s1 in sp1.keys(): 165 spin1 = getattr(s1, "spin", 0) 166 oxi1 = getattr(s1, "oxi_state", 0) 167 for s2 in sp2.keys(): 168 spin2 = getattr(s2, "spin", 0) 169 oxi2 = getattr(s2, "oxi_state", 0) 170 if s1.symbol == s2.symbol and oxi1 == oxi2 and spin2 == -spin1: 171 break 172 else: 173 return False 174 return True 175 176 def get_hash(self, composition): 177 """ 178 Returns: Fractional composition 179 """ 180 return composition.fractional_composition 181 182 183class ElementComparator(AbstractComparator): 184 """ 185 A Comparator that matches elements. i.e. oxidation states are 186 ignored. 187 """ 188 189 def are_equal(self, sp1, sp2): 190 """ 191 True if element:amounts are exactly the same, i.e., 192 oxidation state is not considered. 193 194 Args: 195 sp1: First species. A dict of {specie/element: amt} as per the 196 definition in Site and PeriodicSite. 197 sp2: Second species. A dict of {specie/element: amt} as per the 198 definition in Site and PeriodicSite. 199 200 Returns: 201 Boolean indicating whether species are the same based on element 202 and amounts. 203 """ 204 comp1 = Composition(sp1) 205 comp2 = Composition(sp2) 206 return comp1.get_el_amt_dict() == comp2.get_el_amt_dict() 207 208 def get_hash(self, composition): 209 """ 210 Returns: Fractional element composition 211 """ 212 return composition.element_composition.fractional_composition 213 214 215class FrameworkComparator(AbstractComparator): 216 """ 217 A Comparator that matches sites, regardless of species. 218 """ 219 220 def are_equal(self, sp1, sp2): 221 """ 222 True if there are atoms on both sites. 223 224 Args: 225 sp1: First species. A dict of {specie/element: amt} as per the 226 definition in Site and PeriodicSite. 227 sp2: Second species. A dict of {specie/element: amt} as per the 228 definition in Site and PeriodicSite. 229 230 Returns: 231 True always 232 """ 233 return True 234 235 def get_hash(self, composition): 236 """ 237 No hash possible 238 """ 239 return 1 240 241 242class OrderDisorderElementComparator(AbstractComparator): 243 """ 244 A Comparator that matches sites, given some overlap in the element 245 composition 246 """ 247 248 def are_equal(self, sp1, sp2): 249 """ 250 True if there is some overlap in composition between the species 251 252 Args: 253 sp1: First species. A dict of {specie/element: amt} as per the 254 definition in Site and PeriodicSite. 255 sp2: Second species. A dict of {specie/element: amt} as per the 256 definition in Site and PeriodicSite. 257 258 Returns: 259 True always 260 """ 261 set1 = set(sp1.elements) 262 set2 = set(sp2.elements) 263 return set1.issubset(set2) or set2.issubset(set1) 264 265 def get_hash(self, composition): 266 """ 267 Returns: Fractional composition 268 """ 269 return composition.fractional_composition 270 271 272class OccupancyComparator(AbstractComparator): 273 """ 274 A Comparator that matches occupancies on sites, 275 irrespective of the species of those sites. 276 """ 277 278 def are_equal(self, sp1, sp2): 279 """ 280 Args: 281 sp1: First species. A dict of {specie/element: amt} as per the 282 definition in Site and PeriodicSite. 283 sp2: Second species. A dict of {specie/element: amt} as per the 284 definition in Site and PeriodicSite. 285 286 Returns: 287 True if sets of occupancies (amt) are equal on both sites. 288 """ 289 return set(sp1.element_composition.values()) == set(sp2.element_composition.values()) 290 291 def get_hash(self, composition): 292 """ 293 :param composition: Composition. 294 :return: 1. Difficult to define sensible hash 295 """ 296 return 1 297 298 299class StructureMatcher(MSONable): 300 """ 301 Class to match structures by similarity. 302 303 Algorithm: 304 305 1. Given two structures: s1 and s2 306 2. Optional: Reduce to primitive cells. 307 3. If the number of sites do not match, return False 308 4. Reduce to s1 and s2 to Niggli Cells 309 5. Optional: Scale s1 and s2 to same volume. 310 6. Optional: Remove oxidation states associated with sites 311 7. Find all possible lattice vectors for s2 within shell of ltol. 312 8. For s1, translate an atom in the smallest set to the origin 313 9. For s2: find all valid lattices from permutations of the list 314 of lattice vectors (invalid if: det(Lattice Matrix) < half 315 volume of original s2 lattice) 316 10. For each valid lattice: 317 318 a. If the lattice angles of are within tolerance of s1, 319 basis change s2 into new lattice. 320 b. For each atom in the smallest set of s2: 321 322 i. Translate to origin and compare fractional sites in 323 structure within a fractional tolerance. 324 ii. If true: 325 326 ia. Convert both lattices to cartesian and place 327 both structures on an average lattice 328 ib. Compute and return the average and max rms 329 displacement between the two structures normalized 330 by the average free length per atom 331 332 if fit function called: 333 if normalized max rms displacement is less than 334 stol. Return True 335 336 if get_rms_dist function called: 337 if normalized average rms displacement is less 338 than the stored rms displacement, store and 339 continue. (This function will search all possible 340 lattices for the smallest average rms displacement 341 between the two structures) 342 """ 343 344 def __init__( 345 self, 346 ltol=0.2, 347 stol=0.3, 348 angle_tol=5, 349 primitive_cell=True, 350 scale=True, 351 attempt_supercell=False, 352 allow_subset=False, 353 comparator=SpeciesComparator(), 354 supercell_size="num_sites", 355 ignored_species=None, 356 ): 357 """ 358 Args: 359 ltol (float): Fractional length tolerance. Default is 0.2. 360 stol (float): Site tolerance. Defined as the fraction of the 361 average free length per atom := ( V / Nsites ) ** (1/3) 362 Default is 0.3. 363 angle_tol (float): Angle tolerance in degrees. Default is 5 degrees. 364 primitive_cell (bool): If true: input structures will be reduced to 365 primitive cells prior to matching. Default to True. 366 scale (bool): Input structures are scaled to equivalent volume if 367 true; For exact matching, set to False. 368 attempt_supercell (bool): If set to True and number of sites in 369 cells differ after a primitive cell reduction (divisible by an 370 integer) attempts to generate a supercell transformation of the 371 smaller cell which is equivalent to the larger structure. 372 allow_subset (bool): Allow one structure to match to the subset of 373 another structure. Eg. Matching of an ordered structure onto a 374 disordered one, or matching a delithiated to a lithiated 375 structure. This option cannot be combined with 376 attempt_supercell, or with structure grouping. 377 comparator (Comparator): A comparator object implementing an equals 378 method that declares declaring equivalency of sites. Default is 379 SpeciesComparator, which implies rigid species 380 mapping, i.e., Fe2+ only matches Fe2+ and not Fe3+. 381 382 Other comparators are provided, e.g., ElementComparator which 383 matches only the elements and not the species. 384 385 The reason why a comparator object is used instead of 386 supplying a comparison function is that it is not possible to 387 pickle a function, which makes it otherwise difficult to use 388 StructureMatcher with Python's multiprocessing. 389 supercell_size (str or list): Method to use for determining the 390 size of a supercell (if applicable). Possible values are 391 num_sites, num_atoms, volume, or an element or list of elements 392 present in both structures. 393 ignored_species (list): A list of ions to be ignored in matching. 394 Useful for matching structures that have similar frameworks 395 except for certain ions, e.g., Li-ion intercalation frameworks. 396 This is more useful than allow_subset because it allows better 397 control over what species are ignored in the matching. 398 """ 399 400 self.ltol = ltol 401 self.stol = stol 402 self.angle_tol = angle_tol 403 self._comparator = comparator 404 self._primitive_cell = primitive_cell 405 self._scale = scale 406 self._supercell = attempt_supercell 407 self._supercell_size = supercell_size 408 self._subset = allow_subset 409 self._ignored_species = [] if ignored_species is None else ignored_species[:] 410 411 def _get_supercell_size(self, s1, s2): 412 """ 413 Returns the supercell size, and whether the supercell should 414 be applied to s1. If fu == 1, s1_supercell is returned as 415 true, to avoid ambiguity. 416 """ 417 if self._supercell_size == "num_sites": 418 fu = s2.num_sites / s1.num_sites 419 elif self._supercell_size == "num_atoms": 420 fu = s2.composition.num_atoms / s1.composition.num_atoms 421 elif self._supercell_size == "volume": 422 fu = s2.volume / s1.volume 423 elif not isinstance(self._supercell_size, str): 424 s1comp, s2comp = 0, 0 425 for el in self._supercell_size: 426 el = get_el_sp(el) 427 s1comp += s1.composition[el] 428 s2comp += s2.composition[el] 429 fu = s2comp / s1comp 430 else: 431 el = get_el_sp(self._supercell_size) 432 if (el in s2.composition) and (el in s1.composition): 433 fu = s2.composition[el] / s1.composition[el] 434 else: 435 raise ValueError("Invalid argument for supercell_size.") 436 437 if fu < 2 / 3: 438 return int(round(1 / fu)), False 439 440 return int(round(fu)), True 441 442 def _get_lattices(self, target_lattice, s, supercell_size=1): 443 """ 444 Yields lattices for s with lengths and angles close to the 445 lattice of target_s. If supercell_size is specified, the 446 returned lattice will have that number of primitive cells 447 in it 448 449 Args: 450 s, target_s: Structure objects 451 """ 452 lattices = s.lattice.find_all_mappings( 453 target_lattice, 454 ltol=self.ltol, 455 atol=self.angle_tol, 456 skip_rotation_matrix=True, 457 ) 458 for l, _, scale_m in lattices: 459 if abs(abs(np.linalg.det(scale_m)) - supercell_size) < 0.5: 460 yield l, scale_m 461 462 def _get_supercells(self, struct1, struct2, fu, s1_supercell): 463 """ 464 Computes all supercells of one structure close to the lattice of the 465 other 466 if s1_supercell == True, it makes the supercells of struct1, otherwise 467 it makes them of s2 468 469 yields: s1, s2, supercell_matrix, average_lattice, supercell_matrix 470 """ 471 472 def av_lat(l1, l2): 473 params = (np.array(l1.parameters) + np.array(l2.parameters)) / 2 474 return Lattice.from_parameters(*params) 475 476 def sc_generator(s1, s2): 477 s2_fc = np.array(s2.frac_coords) 478 if fu == 1: 479 cc = np.array(s1.cart_coords) 480 for l, sc_m in self._get_lattices(s2.lattice, s1, fu): 481 fc = l.get_fractional_coords(cc) 482 fc -= np.floor(fc) 483 yield fc, s2_fc, av_lat(l, s2.lattice), sc_m 484 else: 485 fc_init = np.array(s1.frac_coords) 486 for l, sc_m in self._get_lattices(s2.lattice, s1, fu): 487 fc = np.dot(fc_init, np.linalg.inv(sc_m)) 488 lp = lattice_points_in_supercell(sc_m) 489 fc = (fc[:, None, :] + lp[None, :, :]).reshape((-1, 3)) 490 fc -= np.floor(fc) 491 yield fc, s2_fc, av_lat(l, s2.lattice), sc_m 492 493 if s1_supercell: 494 for x in sc_generator(struct1, struct2): 495 yield x 496 else: 497 for x in sc_generator(struct2, struct1): 498 # reorder generator output so s1 is still first 499 yield x[1], x[0], x[2], x[3] 500 501 @classmethod 502 def _cmp_fstruct(cls, s1, s2, frac_tol, mask): 503 """ 504 Returns true if a matching exists between s2 and s2 505 under frac_tol. s2 should be a subset of s1 506 """ 507 if len(s2) > len(s1): 508 raise ValueError("s1 must be larger than s2") 509 if mask.shape != (len(s2), len(s1)): 510 raise ValueError("mask has incorrect shape") 511 512 return is_coord_subset_pbc(s2, s1, frac_tol, mask) 513 514 @classmethod 515 def _cart_dists(cls, s1, s2, avg_lattice, mask, normalization, lll_frac_tol=None): 516 """ 517 Finds a matching in cartesian space. Finds an additional 518 fractional translation vector to minimize RMS distance 519 520 Args: 521 s1, s2: numpy arrays of fractional coordinates. len(s1) >= len(s2) 522 avg_lattice: Lattice on which to calculate distances 523 mask: numpy array of booleans. mask[i, j] = True indicates 524 that s2[i] cannot be matched to s1[j] 525 normalization (float): inverse normalization length 526 527 Returns: 528 Distances from s2 to s1, normalized by (V/Natom) ^ 1/3 529 Fractional translation vector to apply to s2. 530 Mapping from s1 to s2, i.e. with numpy slicing, s1[mapping] => s2 531 """ 532 if len(s2) > len(s1): 533 raise ValueError("s1 must be larger than s2") 534 if mask.shape != (len(s2), len(s1)): 535 raise ValueError("mask has incorrect shape") 536 537 # vectors are from s2 to s1 538 vecs, d_2 = pbc_shortest_vectors(avg_lattice, s2, s1, mask, return_d2=True, lll_frac_tol=lll_frac_tol) 539 lin = LinearAssignment(d_2) 540 s = lin.solution # pylint: disable=E1101 541 short_vecs = vecs[np.arange(len(s)), s] 542 translation = np.average(short_vecs, axis=0) 543 f_translation = avg_lattice.get_fractional_coords(translation) 544 new_d2 = np.sum((short_vecs - translation) ** 2, axis=-1) 545 546 return new_d2 ** 0.5 * normalization, f_translation, s 547 548 def _get_mask(self, struct1, struct2, fu, s1_supercell): 549 """ 550 Returns mask for matching struct2 to struct1. If struct1 has sites 551 a b c, and fu = 2, assumes supercells of struct2 will be ordered 552 aabbcc (rather than abcabc) 553 554 Returns: 555 mask, struct1 translation indices, struct2 translation index 556 """ 557 mask = np.zeros((len(struct2), len(struct1), fu), dtype=np.bool) 558 559 inner = [] 560 for sp2, i in itertools.groupby(enumerate(struct2.species_and_occu), key=lambda x: x[1]): 561 i = list(i) 562 inner.append((sp2, slice(i[0][0], i[-1][0] + 1))) 563 564 for sp1, j in itertools.groupby(enumerate(struct1.species_and_occu), key=lambda x: x[1]): 565 j = list(j) 566 j = slice(j[0][0], j[-1][0] + 1) 567 for sp2, i in inner: 568 mask[i, j, :] = not self._comparator.are_equal(sp1, sp2) 569 570 if s1_supercell: 571 mask = mask.reshape((len(struct2), -1)) 572 else: 573 # supercell is of struct2, roll fu axis back to preserve 574 # correct ordering 575 mask = np.rollaxis(mask, 2, 1) 576 mask = mask.reshape((-1, len(struct1))) 577 578 # find the best translation indices 579 i = np.argmax(np.sum(mask, axis=-1)) 580 inds = np.where(np.invert(mask[i]))[0] 581 if s1_supercell: 582 # remove the symmetrically equivalent s1 indices 583 inds = inds[::fu] 584 return np.array(mask, dtype=np.int_), inds, i 585 586 def fit(self, struct1, struct2, symmetric=False): 587 """ 588 Fit two structures. 589 590 Args: 591 struct1 (Structure): 1st structure 592 struct2 (Structure): 2nd structure 593 symmetric (Bool): Defaults to False 594 If True, check the equality both ways. 595 This only impacts a small percentage of structures 596 597 Returns: 598 True or False. 599 """ 600 struct1, struct2 = self._process_species([struct1, struct2]) 601 602 if not self._subset and self._comparator.get_hash(struct1.composition) != self._comparator.get_hash( 603 struct2.composition 604 ): 605 return None 606 607 if not symmetric: 608 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) 609 match = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True) 610 if match is None: 611 return False 612 613 return match[0] <= self.stol 614 615 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) 616 match1 = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True) 617 struct1, struct2 = struct2, struct1 618 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) 619 match2 = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True) 620 621 if match1 is None or match2 is None: 622 return False 623 624 return max(match1[0], match2[0]) <= self.stol 625 626 def get_rms_dist(self, struct1, struct2): 627 """ 628 Calculate RMS displacement between two structures 629 630 Args: 631 struct1 (Structure): 1st structure 632 struct2 (Structure): 2nd structure 633 634 Returns: 635 rms displacement normalized by (Vol / nsites) ** (1/3) 636 and maximum distance between paired sites. If no matching 637 lattice is found None is returned. 638 """ 639 struct1, struct2 = self._process_species([struct1, struct2]) 640 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) 641 match = self._match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False) 642 643 if match is None: 644 return None 645 646 return match[0], max(match[1]) 647 648 def _process_species(self, structures): 649 copied_structures = [] 650 for s in structures: 651 # We need the copies to be actual Structure to work properly, not 652 # subclasses. So do type(s) == Structure. 653 ss = Structure.from_sites(s) 654 if self._ignored_species: 655 ss.remove_species(self._ignored_species) 656 copied_structures.append(ss) 657 return copied_structures 658 659 def _preprocess(self, struct1, struct2, niggli=True): 660 """ 661 Rescales, finds the reduced structures (primitive and niggli), 662 and finds fu, the supercell size to make struct1 comparable to 663 s2 664 """ 665 struct1 = struct1.copy() 666 struct2 = struct2.copy() 667 668 if niggli: 669 struct1 = struct1.get_reduced_structure(reduction_algo="niggli") 670 struct2 = struct2.get_reduced_structure(reduction_algo="niggli") 671 672 # primitive cell transformation 673 if self._primitive_cell: 674 struct1 = struct1.get_primitive_structure() 675 struct2 = struct2.get_primitive_structure() 676 677 if self._supercell: 678 fu, s1_supercell = self._get_supercell_size(struct1, struct2) 679 else: 680 fu, s1_supercell = 1, True 681 mult = fu if s1_supercell else 1 / fu 682 683 # rescale lattice to same volume 684 if self._scale: 685 ratio = (struct2.volume / (struct1.volume * mult)) ** (1 / 6) 686 nl1 = Lattice(struct1.lattice.matrix * ratio) 687 struct1.lattice = nl1 688 nl2 = Lattice(struct2.lattice.matrix / ratio) 689 struct2.lattice = nl2 690 691 return struct1, struct2, fu, s1_supercell 692 693 def _match( 694 self, 695 struct1, 696 struct2, 697 fu, 698 s1_supercell=True, 699 use_rms=False, 700 break_on_match=False, 701 ): 702 """ 703 Matches one struct onto the other 704 """ 705 ratio = fu if s1_supercell else 1 / fu 706 if len(struct1) * ratio >= len(struct2): 707 return self._strict_match( 708 struct1, 709 struct2, 710 fu, 711 s1_supercell=s1_supercell, 712 break_on_match=break_on_match, 713 use_rms=use_rms, 714 ) 715 return self._strict_match( 716 struct2, 717 struct1, 718 fu, 719 s1_supercell=(not s1_supercell), 720 break_on_match=break_on_match, 721 use_rms=use_rms, 722 ) 723 724 def _strict_match( 725 self, 726 struct1, 727 struct2, 728 fu, 729 s1_supercell=True, 730 use_rms=False, 731 break_on_match=False, 732 ): 733 """ 734 Matches struct2 onto struct1 (which should contain all sites in 735 struct2). 736 737 Args: 738 struct1, struct2 (Structure): structures to be matched 739 fu (int): size of supercell to create 740 s1_supercell (bool): whether to create the supercell of 741 struct1 (vs struct2) 742 use_rms (bool): whether to minimize the rms of the matching 743 break_on_match (bool): whether to stop search at first 744 valid match 745 """ 746 if fu < 1: 747 raise ValueError("fu cannot be less than 1") 748 749 mask, s1_t_inds, s2_t_ind = self._get_mask(struct1, struct2, fu, s1_supercell) 750 751 if mask.shape[0] > mask.shape[1]: 752 raise ValueError("after supercell creation, struct1 must " "have more sites than struct2") 753 754 # check that a valid mapping exists 755 if (not self._subset) and mask.shape[1] != mask.shape[0]: 756 return None 757 758 if LinearAssignment(mask).min_cost > 0: # pylint: disable=E1101 759 return None 760 761 best_match = None 762 # loop over all lattices 763 for s1fc, s2fc, avg_l, sc_m in self._get_supercells(struct1, struct2, fu, s1_supercell): 764 # compute fractional tolerance 765 normalization = (len(s1fc) / avg_l.volume) ** (1 / 3) 766 inv_abc = np.array(avg_l.reciprocal_lattice.abc) 767 frac_tol = inv_abc * self.stol / (np.pi * normalization) 768 # loop over all translations 769 for s1i in s1_t_inds: 770 t = s1fc[s1i] - s2fc[s2_t_ind] 771 t_s2fc = s2fc + t 772 if self._cmp_fstruct(s1fc, t_s2fc, frac_tol, mask): 773 inv_lll_abc = np.array(avg_l.get_lll_reduced_lattice().reciprocal_lattice.abc) 774 lll_frac_tol = inv_lll_abc * self.stol / (np.pi * normalization) 775 dist, t_adj, mapping = self._cart_dists(s1fc, t_s2fc, avg_l, mask, normalization, lll_frac_tol) 776 if use_rms: 777 val = np.linalg.norm(dist) / len(dist) ** 0.5 778 else: 779 val = max(dist) 780 # pylint: disable=E1136 781 if best_match is None or val < best_match[0]: 782 total_t = t + t_adj 783 total_t -= np.round(total_t) 784 best_match = val, dist, sc_m, total_t, mapping 785 if (break_on_match or val < 1e-5) and val < self.stol: 786 return best_match 787 788 if best_match and best_match[0] < self.stol: 789 return best_match 790 791 return None 792 793 def group_structures(self, s_list, anonymous=False): 794 """ 795 Given a list of structures, use fit to group 796 them by structural equality. 797 798 Args: 799 s_list ([Structure]): List of structures to be grouped 800 anonymous (bool): Whether to use anonymous mode. 801 802 Returns: 803 A list of lists of matched structures 804 Assumption: if s1 == s2 but s1 != s3, than s2 and s3 will be put 805 in different groups without comparison. 806 """ 807 if self._subset: 808 raise ValueError("allow_subset cannot be used with" " group_structures") 809 810 original_s_list = list(s_list) 811 s_list = self._process_species(s_list) 812 813 # Use structure hash to pre-group structures 814 if anonymous: 815 816 def c_hash(c): 817 return c.anonymized_formula 818 819 else: 820 c_hash = self._comparator.get_hash 821 822 def s_hash(s): 823 return c_hash(s[1].composition) 824 825 sorted_s_list = sorted(enumerate(s_list), key=s_hash) 826 all_groups = [] 827 828 # For each pre-grouped list of structures, perform actual matching. 829 for k, g in itertools.groupby(sorted_s_list, key=s_hash): 830 unmatched = list(g) 831 while len(unmatched) > 0: 832 i, refs = unmatched.pop(0) 833 matches = [i] 834 if anonymous: 835 inds = filter( 836 lambda i: self.fit_anonymous(refs, unmatched[i][1]), 837 list(range(len(unmatched))), 838 ) 839 else: 840 inds = filter( 841 lambda i: self.fit(refs, unmatched[i][1]), 842 list(range(len(unmatched))), 843 ) 844 inds = list(inds) 845 matches.extend([unmatched[i][0] for i in inds]) 846 unmatched = [unmatched[i] for i in range(len(unmatched)) if i not in inds] 847 all_groups.append([original_s_list[i] for i in matches]) 848 849 return all_groups 850 851 def as_dict(self): 852 """ 853 :return: MSONable dict 854 """ 855 return { 856 "version": __version__, 857 "@module": self.__class__.__module__, 858 "@class": self.__class__.__name__, 859 "comparator": self._comparator.as_dict(), 860 "stol": self.stol, 861 "ltol": self.ltol, 862 "angle_tol": self.angle_tol, 863 "primitive_cell": self._primitive_cell, 864 "scale": self._scale, 865 "attempt_supercell": self._supercell, 866 "allow_subset": self._subset, 867 "supercell_size": self._supercell_size, 868 "ignored_species": self._ignored_species, 869 } 870 871 @classmethod 872 def from_dict(cls, d): 873 """ 874 :param d: Dict representation 875 :return: StructureMatcher 876 """ 877 return StructureMatcher( 878 ltol=d["ltol"], 879 stol=d["stol"], 880 angle_tol=d["angle_tol"], 881 primitive_cell=d["primitive_cell"], 882 scale=d["scale"], 883 attempt_supercell=d["attempt_supercell"], 884 allow_subset=d["allow_subset"], 885 comparator=AbstractComparator.from_dict(d["comparator"]), 886 supercell_size=d["supercell_size"], 887 ignored_species=d["ignored_species"], 888 ) 889 890 def _anonymous_match( 891 self, 892 struct1, 893 struct2, 894 fu, 895 s1_supercell=True, 896 use_rms=False, 897 break_on_match=False, 898 single_match=False, 899 ): 900 """ 901 Tries all permutations of matching struct1 to struct2. 902 Args: 903 struct1, struct2 (Structure): Preprocessed input structures 904 Returns: 905 List of (mapping, match) 906 """ 907 if not isinstance(self._comparator, SpeciesComparator): 908 raise ValueError("Anonymous fitting currently requires SpeciesComparator") 909 910 # check that species lists are comparable 911 sp1 = struct1.composition.elements 912 sp2 = struct2.composition.elements 913 if len(sp1) != len(sp2): 914 return None 915 916 ratio = fu if s1_supercell else 1 / fu 917 swapped = len(struct1) * ratio < len(struct2) 918 919 s1_comp = struct1.composition 920 s2_comp = struct2.composition 921 matches = [] 922 for perm in itertools.permutations(sp2): 923 sp_mapping = dict(zip(sp1, perm)) 924 925 # do quick check that compositions are compatible 926 mapped_comp = Composition({sp_mapping[k]: v for k, v in s1_comp.items()}) 927 if (not self._subset) and (self._comparator.get_hash(mapped_comp) != self._comparator.get_hash(s2_comp)): 928 continue 929 930 mapped_struct = struct1.copy() 931 mapped_struct.replace_species(sp_mapping) 932 if swapped: 933 m = self._strict_match( 934 struct2, 935 mapped_struct, 936 fu, 937 (not s1_supercell), 938 use_rms, 939 break_on_match, 940 ) 941 else: 942 m = self._strict_match(mapped_struct, struct2, fu, s1_supercell, use_rms, break_on_match) 943 if m: 944 matches.append((sp_mapping, m)) 945 if single_match: 946 break 947 return matches 948 949 def get_rms_anonymous(self, struct1, struct2): 950 """ 951 Performs an anonymous fitting, which allows distinct species in one 952 structure to map to another. E.g., to compare if the Li2O and Na2O 953 structures are similar. 954 955 Args: 956 struct1 (Structure): 1st structure 957 struct2 (Structure): 2nd structure 958 959 Returns: 960 (min_rms, min_mapping) 961 min_rms is the minimum rms distance, and min_mapping is the 962 corresponding minimal species mapping that would map 963 struct1 to struct2. (None, None) is returned if the minimax_rms 964 exceeds the threshold. 965 """ 966 struct1, struct2 = self._process_species([struct1, struct2]) 967 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) 968 969 matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False) 970 if matches: 971 best = sorted(matches, key=lambda x: x[1][0])[0] 972 return best[1][0], best[0] 973 974 return None, None 975 976 def get_best_electronegativity_anonymous_mapping(self, struct1, struct2): 977 """ 978 Performs an anonymous fitting, which allows distinct species in one 979 structure to map to another. E.g., to compare if the Li2O and Na2O 980 structures are similar. If multiple substitutions are within tolerance 981 this will return the one which minimizes the difference in 982 electronegativity between the matches species. 983 984 Args: 985 struct1 (Structure): 1st structure 986 struct2 (Structure): 2nd structure 987 988 Returns: 989 min_mapping (Dict): Mapping of struct1 species to struct2 species 990 """ 991 struct1, struct2 = self._process_species([struct1, struct2]) 992 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) 993 994 matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=True) 995 996 if matches: 997 min_X_diff = np.inf 998 for m in matches: 999 X_diff = 0 1000 for k, v in m[0].items(): 1001 X_diff += struct1.composition[k] * (k.X - v.X) ** 2 1002 if X_diff < min_X_diff: 1003 min_X_diff = X_diff 1004 best = m[0] 1005 return best 1006 1007 return None 1008 1009 def get_all_anonymous_mappings(self, struct1, struct2, niggli=True, include_dist=False): 1010 """ 1011 Performs an anonymous fitting, which allows distinct species in one 1012 structure to map to another. Returns a dictionary of species 1013 substitutions that are within tolerance 1014 1015 Args: 1016 struct1 (Structure): 1st structure 1017 struct2 (Structure): 2nd structure 1018 niggli (bool): Find niggli cell in preprocessing 1019 include_dist (bool): Return the maximin distance with each mapping 1020 1021 Returns: 1022 list of species mappings that map struct1 to struct2. 1023 """ 1024 struct1, struct2 = self._process_species([struct1, struct2]) 1025 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli) 1026 1027 matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, break_on_match=not include_dist) 1028 if matches: 1029 if include_dist: 1030 return [(m[0], m[1][0]) for m in matches] 1031 1032 return [m[0] for m in matches] 1033 1034 return None 1035 1036 def fit_anonymous(self, struct1, struct2, niggli=True): 1037 """ 1038 Performs an anonymous fitting, which allows distinct species in one 1039 structure to map to another. E.g., to compare if the Li2O and Na2O 1040 structures are similar. 1041 1042 Args: 1043 struct1 (Structure): 1st structure 1044 struct2 (Structure): 2nd structure 1045 1046 Returns: 1047 True/False: Whether a species mapping can map struct1 to stuct2 1048 """ 1049 struct1, struct2 = self._process_species([struct1, struct2]) 1050 struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli) 1051 1052 matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, break_on_match=True, single_match=True) 1053 1054 return bool(matches) 1055 1056 def get_supercell_matrix(self, supercell, struct): 1057 """ 1058 Returns the matrix for transforming struct to supercell. This 1059 can be used for very distorted 'supercells' where the primitive cell 1060 is impossible to find 1061 """ 1062 if self._primitive_cell: 1063 raise ValueError("get_supercell_matrix cannot be used with the " "primitive cell option") 1064 struct, supercell, fu, s1_supercell = self._preprocess(struct, supercell, False) 1065 1066 if not s1_supercell: 1067 raise ValueError( 1068 "The non-supercell must be put onto the basis" " of the supercell, not the other way around" 1069 ) 1070 1071 match = self._match(struct, supercell, fu, s1_supercell, use_rms=True, break_on_match=False) 1072 1073 if match is None: 1074 return None 1075 1076 return match[2] 1077 1078 def get_transformation(self, struct1, struct2): 1079 """ 1080 Returns the supercell transformation, fractional translation vector, 1081 and a mapping to transform struct2 to be similar to struct1. 1082 1083 Args: 1084 struct1 (Structure): Reference structure 1085 struct2 (Structure): Structure to transform. 1086 1087 Returns: 1088 supercell (numpy.ndarray(3, 3)): supercell matrix 1089 vector (numpy.ndarray(3)): fractional translation vector 1090 mapping (list(int or None)): 1091 The first len(struct1) items of the mapping vector are the 1092 indices of struct1's corresponding sites in struct2 (or None 1093 if there is no corresponding site), and the other items are 1094 the remaining site indices of struct2. 1095 """ 1096 if self._primitive_cell: 1097 raise ValueError("get_transformation cannot be used with the " "primitive cell option") 1098 1099 struct1, struct2 = self._process_species((struct1, struct2)) 1100 1101 s1, s2, fu, s1_supercell = self._preprocess(struct1, struct2, False) 1102 ratio = fu if s1_supercell else 1 / fu 1103 if s1_supercell and fu > 1: 1104 raise ValueError("Struct1 must be the supercell, " "not the other way around") 1105 1106 if len(s1) * ratio >= len(s2): 1107 # s1 is superset 1108 match = self._strict_match(s1, s2, fu=fu, s1_supercell=False, use_rms=True, break_on_match=False) 1109 if match is None: 1110 return None 1111 # invert the mapping, since it needs to be from s1 to s2 1112 mapping = [list(match[4]).index(i) if i in match[4] else None for i in range(len(s1))] 1113 return match[2], match[3], mapping 1114 # s2 is superset 1115 match = self._strict_match(s2, s1, fu=fu, s1_supercell=True, use_rms=True, break_on_match=False) 1116 if match is None: 1117 return None 1118 # add sites not included in the mapping 1119 not_included = list(range(len(s2) * fu)) 1120 for i in match[4]: 1121 not_included.remove(i) 1122 mapping = list(match[4]) + not_included 1123 return match[2], -match[3], mapping 1124 1125 def get_s2_like_s1(self, struct1, struct2, include_ignored_species=True): 1126 """ 1127 Performs transformations on struct2 to put it in a basis similar to 1128 struct1 (without changing any of the inter-site distances) 1129 1130 Args: 1131 struct1 (Structure): Reference structure 1132 struct2 (Structure): Structure to transform. 1133 include_ignored_species (bool): Defaults to True, 1134 the ignored_species is also transformed to the struct1 1135 lattice orientation, though obviously there is no direct 1136 matching to existing sites. 1137 1138 Returns: 1139 A structure object similar to struct1, obtained by making a 1140 supercell, sorting, and translating struct2. 1141 """ 1142 s1, s2 = self._process_species([struct1, struct2]) 1143 trans = self.get_transformation(s1, s2) 1144 if trans is None: 1145 return None 1146 sc, t, mapping = trans 1147 sites = list(s2) 1148 # Append the ignored sites at the end. 1149 sites.extend([site for site in struct2 if site not in s2]) 1150 temp = Structure.from_sites(sites) 1151 1152 temp.make_supercell(sc) 1153 temp.translate_sites(list(range(len(temp))), t) 1154 # translate sites to correct unit cell 1155 for i, j in enumerate(mapping[: len(s1)]): 1156 if j is not None: 1157 vec = np.round(struct1[i].frac_coords - temp[j].frac_coords) 1158 temp.translate_sites(j, vec, to_unit_cell=False) 1159 1160 sites = [temp.sites[i] for i in mapping if i is not None] 1161 1162 if include_ignored_species: 1163 start = int(round(len(temp) / len(struct2) * len(s2))) 1164 sites.extend(temp.sites[start:]) 1165 1166 return Structure.from_sites(sites) 1167 1168 def get_mapping(self, superset, subset): 1169 """ 1170 Calculate the mapping from superset to subset. 1171 1172 Args: 1173 superset (Structure): Structure containing at least the sites in 1174 subset (within the structure matching tolerance) 1175 subset (Structure): Structure containing some of the sites in 1176 superset (within the structure matching tolerance) 1177 1178 Returns: 1179 numpy array such that superset.sites[mapping] is within matching 1180 tolerance of subset.sites or None if no such mapping is possible 1181 """ 1182 if self._supercell: 1183 raise ValueError("cannot compute mapping to supercell") 1184 if self._primitive_cell: 1185 raise ValueError("cannot compute mapping with primitive cell " "option") 1186 if len(subset) > len(superset): 1187 raise ValueError("subset is larger than superset") 1188 1189 superset, subset, _, _ = self._preprocess(superset, subset, True) 1190 match = self._strict_match(superset, subset, 1, break_on_match=False) 1191 1192 if match is None or match[0] > self.stol: 1193 return None 1194 1195 return match[4] 1196 1197 1198class PointDefectComparator(MSONable): 1199 """ 1200 A class that matches pymatgen Point Defect objects even if their 1201 cartesian co-ordinates are different (compares sublattices for the defect) 1202 1203 NOTE: for defect complexes (more than a single defect), 1204 this comparator will break. 1205 """ 1206 1207 def __init__(self, check_charge=False, check_primitive_cell=False, check_lattice_scale=False): 1208 """ 1209 Args: 1210 check_charge (bool): Gives option to check 1211 if charges are identical. 1212 Default is False (different charged defects can be same) 1213 check_primitive_cell (bool): Gives option to 1214 compare different supercells of bulk_structure, 1215 rather than directly compare supercell sizes 1216 Default is False (requires bulk_structure in each defect to be same size) 1217 check_lattice_scale (bool): Gives option to scale volumes of 1218 structures to each other identical lattice constants. 1219 Default is False (enforces same 1220 lattice constants in both structures) 1221 """ 1222 self.check_charge = check_charge 1223 self.check_primitive_cell = check_primitive_cell 1224 self.check_lattice_scale = check_lattice_scale 1225 1226 def are_equal(self, d1, d2): 1227 """ 1228 Args: 1229 d1: First defect. A pymatgen Defect object. 1230 d2: Second defect. A pymatgen Defect object. 1231 1232 Returns: 1233 True if defects are identical in type and sublattice. 1234 """ 1235 possible_defect_types = (Defect, Vacancy, Substitution, Interstitial) 1236 1237 if not isinstance(d1, possible_defect_types) or not isinstance(d2, possible_defect_types): 1238 raise ValueError("Cannot use PointDefectComparator to" " compare non-defect objects...") 1239 1240 if not isinstance(d1, d2.__class__): 1241 return False 1242 if d1.site.specie != d2.site.specie: 1243 return False 1244 if self.check_charge and (d1.charge != d2.charge): 1245 return False 1246 1247 sm = StructureMatcher( 1248 ltol=0.01, 1249 primitive_cell=self.check_primitive_cell, 1250 scale=self.check_lattice_scale, 1251 ) 1252 1253 if not sm.fit(d1.bulk_structure, d2.bulk_structure): 1254 return False 1255 1256 d1 = d1.copy() 1257 d2 = d2.copy() 1258 if self.check_primitive_cell or self.check_lattice_scale: 1259 # if allowing for base structure volume or supercell modifications, 1260 # then need to preprocess defect objects to allow for matching 1261 d1_mod_bulk_structure, d2_mod_bulk_structure, _, _ = sm._preprocess(d1.bulk_structure, d2.bulk_structure) 1262 d1_defect_site = PeriodicSite( 1263 d1.site.specie, 1264 d1.site.coords, 1265 d1_mod_bulk_structure.lattice, 1266 to_unit_cell=True, 1267 coords_are_cartesian=True, 1268 ) 1269 d2_defect_site = PeriodicSite( 1270 d2.site.specie, 1271 d2.site.coords, 1272 d2_mod_bulk_structure.lattice, 1273 to_unit_cell=True, 1274 coords_are_cartesian=True, 1275 ) 1276 1277 d1._structure = d1_mod_bulk_structure 1278 d2._structure = d2_mod_bulk_structure 1279 d1._defect_site = d1_defect_site 1280 d2._defect_site = d2_defect_site 1281 1282 return sm.fit(d1.generate_defect_structure(), d2.generate_defect_structure()) 1283