# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.


"""
This module implements functions to perform various useful operations on
entries, such as grouping entries by structure.
"""


import collections
import collections.abc
import csv
import datetime
import itertools
import json
import logging
import re
from typing import Iterable, List, Set, Union

from monty.json import MontyDecoder, MontyEncoder, MSONable
from monty.string import unicode2str

from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.analysis.structure_matcher import SpeciesComparator, StructureMatcher
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import Element
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry

logger = logging.getLogger(__name__)


def _get_host(structure, species_to_remove):
    """
    Return the host framework of a structure.

    Args:
        structure: Structure to strip.
        species_to_remove: Species to remove. If falsy, the input structure
            is returned as is (not copied).

    Returns:
        A copy of the structure with the species removed, or the original
        structure when there is nothing to remove.
    """
    if species_to_remove:
        s = structure.copy()
        s.remove_species(species_to_remove)
        return s
    return structure


def _perform_grouping(args):
    """
    Group entries whose host structures match and append each group, as a
    JSON string, to the shared ``groups`` list.

    Args:
        args: A single tuple (so that the function can be used with
            ``Pool.map``) of (entries_json, hosts_json, ltol, stol,
            angle_tol, primitive_cell, scale, comparator, groups).
            ``entries_json`` and ``hosts_json`` are JSON lists of equal
            length encoded with MontyEncoder; ``groups`` is a list-like
            object that receives one JSON-encoded list of entries per group.
    """
    (
        entries_json,
        hosts_json,
        ltol,
        stol,
        angle_tol,
        primitive_cell,
        scale,
        comparator,
        groups,
    ) = args

    entries = json.loads(entries_json, cls=MontyDecoder)
    hosts = json.loads(hosts_json, cls=MontyDecoder)

    # The matcher only depends on the tolerances, so build it once instead of
    # once per pairwise comparison.
    matcher = StructureMatcher(
        ltol=ltol,
        stol=stol,
        angle_tol=angle_tol,
        primitive_cell=primitive_cell,
        scale=scale,
        comparator=comparator,
    )

    unmatched = list(zip(entries, hosts))
    while unmatched:
        # The first unmatched entry seeds a new group; everything that fits
        # its host joins the group, the rest is retried in the next round.
        ref_entry, ref_host = unmatched[0]
        logger.info("Reference tid = %s, formula = %s", ref_entry.entry_id, ref_host.formula)
        logger.info("Reference host = %s", ref_host.composition.reduced_formula)
        matches = [unmatched[0]]
        remaining = []
        for candidate in unmatched[1:]:
            test_entry, test_host = candidate
            logger.info("Testing tid = %s, formula = %s", test_entry.entry_id, test_host.formula)
            logger.info("Test host = %s", test_host.composition.reduced_formula)
            if matcher.fit(ref_host, test_host):
                logger.info("Fit found")
                matches.append(candidate)
            else:
                remaining.append(candidate)
        groups.append(json.dumps([entry for entry, _ in matches], cls=MontyEncoder))
        unmatched = remaining
        logger.info("%s unmatched remaining", len(unmatched))


def group_entries_by_structure(
    entries,
    species_to_remove=None,
    ltol=0.2,
    stol=0.4,
    angle_tol=5,
    primitive_cell=True,
    scale=True,
    comparator=SpeciesComparator(),
    ncpus=None,
):
    """
    Given a sequence of ComputedStructureEntries, use structure fitter to group
    them by structural similarity.

    Args:
        entries: Sequence of ComputedStructureEntries.
        species_to_remove: Sometimes you want to compare a host framework
            (e.g., in Li-ion battery analysis). This allows you to specify
            species to remove before structural comparison.
        ltol (float): Fractional length tolerance. Default is 0.2.
        stol (float): Site tolerance in Angstrom. Default is 0.4 Angstrom.
        angle_tol (float): Angle tolerance in degrees. Default is 5 degrees.
        primitive_cell (bool): If true: input structures will be reduced to
            primitive cells prior to matching. Defaults to True.
        scale: Input structures are scaled to equivalent volume if true;
            For exact matching, set to False.
        comparator: A comparator object implementing an equals method that
            declares equivalency of sites. Default is SpeciesComparator,
            which implies rigid species mapping.
        ncpus: Number of cpus to use. Use of multiple cpus can greatly improve
            fitting speed. Default of None means serial processing.

    Returns:
        Sequence of sequence of entries by structural similarity. e.g,
        [[ entry1, entry2], [entry3, entry4, entry5]]
    """
    start = datetime.datetime.now()
    logger.info("Started at %s", start)
    entries_host = [(entry, _get_host(entry.structure, species_to_remove)) for entry in entries]
    if ncpus:
        # Pre-bucket by the comparator's structure hash so that each worker
        # only has to compare entries within one bucket.
        symm_entries = collections.defaultdict(list)
        for entry, host in entries_host:
            symm_entries[comparator.get_structure_hash(host)].append((entry, host))
        import multiprocessing as mp

        logger.info("Using %s cpus", ncpus)
        manager = mp.Manager()
        groups = manager.list()
        with mp.Pool(ncpus) as p:
            # Parallel processing only supports Python primitives and not objects.
            p.map(
                _perform_grouping,
                [
                    (
                        json.dumps([e[0] for e in eh], cls=MontyEncoder),
                        json.dumps([e[1] for e in eh], cls=MontyEncoder),
                        ltol,
                        stol,
                        angle_tol,
                        primitive_cell,
                        scale,
                        comparator,
                        groups,
                    )
                    for eh in symm_entries.values()
                ],
            )
    else:
        groups = []
        # Serialise the materialised lists rather than the caller's object:
        # ``entries`` may be a generator (already consumed above) or a
        # set-like container that does not encode to a JSON list.
        entry_list = [entry for entry, _ in entries_host]
        hosts = [host for _, host in entries_host]
        _perform_grouping(
            (
                json.dumps(entry_list, cls=MontyEncoder),
                json.dumps(hosts, cls=MontyEncoder),
                ltol,
                stol,
                angle_tol,
                primitive_cell,
                scale,
                comparator,
                groups,
            )
        )
    entry_groups = [json.loads(g, cls=MontyDecoder) for g in groups]
    finish = datetime.datetime.now()
    logger.info("Finished at %s", finish)
    logger.info("Took %s", finish - start)
    return entry_groups


class EntrySet(collections.abc.MutableSet, MSONable):
    """
    A convenient container for manipulating entries. Allows for generating
    subsets, dumping into files, etc.
    """

    def __init__(self, entries: Iterable[Union[PDEntry, ComputedEntry, ComputedStructureEntry]]):
        """
        Args:
            entries: All the entries.
        """
        self.entries = set(entries)

    def __contains__(self, item):
        return item in self.entries

    def __iter__(self):
        return iter(self.entries)

    def __len__(self):
        return len(self.entries)

    def add(self, element):
        """
        Add an entry.

        :param element: Entry
        """
        self.entries.add(element)

    def discard(self, element):
        """
        Discard an entry.

        :param element: Entry
        """
        self.entries.discard(element)

    @property
    def chemsys(self) -> set:
        """
        Returns:
            set representing the chemical system, e.g., {"Li", "Fe", "P", "O"}
        """
        return {el.symbol for e in self.entries for el in e.composition.keys()}

    def remove_non_ground_states(self):
        """
        Removes all non-ground state entries, i.e., only keep the lowest energy
        per atom entry at each composition.
        """

        def reduced_formula(e):
            return e.composition.reduced_formula

        # groupby only merges adjacent items, hence the sort on the same key.
        entries = sorted(self.entries, key=reduced_formula)
        self.entries = {
            min(group, key=lambda e: e.energy_per_atom) for _, group in itertools.groupby(entries, key=reduced_formula)
        }

    def get_subset_in_chemsys(self, chemsys: List[str]):
        """
        Returns an EntrySet containing only the set of entries belonging to
        a particular chemical system (in this definition, it includes all sub
        systems). For example, if the entries are from the
        Li-Fe-P-O system, and chemsys=["Li", "O"], only the Li, O,
        and Li-O entries are returned.

        Args:
            chemsys: Chemical system specified as list of elements. E.g.,
                ["Li", "O"]

        Returns:
            EntrySet

        Raises:
            ValueError: If chemsys contains an element that is not present
                in any entry of this set.
        """
        chem_sys = set(chemsys)
        if not chem_sys.issubset(self.chemsys):
            raise ValueError("%s is not a subset of %s" % (chem_sys, self.chemsys))
        subset = {e for e in self.entries if chem_sys.issuperset(sp.symbol for sp in e.composition.keys())}
        return EntrySet(subset)

    def as_dict(self):
        """
        :return: MSONable dict
        """
        return {"entries": list(self.entries)}

    def to_csv(self, filename: str, latexify_names: bool = False):
        """
        Exports the entries to a csv file. The header row is
        "Name", one column per element (sorted by electronegativity) and
        "Energy".

        Args:
            filename: Filename to write to.
            latexify_names: Format entry names to be LaTex compatible,
                e.g., Li_{2}O
        """

        els = set()  # type: Set[Element]
        for entry in self.entries:
            els.update(entry.composition.elements)
        elements = sorted(els, key=lambda a: a.X)
        # newline="" is what the csv module expects (avoids blank rows on
        # Windows); utf-8 matches what from_csv reads.
        with open(filename, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(
                f,
                delimiter=unicode2str(","),
                quotechar=unicode2str('"'),
                quoting=csv.QUOTE_MINIMAL,
            )
            writer.writerow(["Name"] + [el.symbol for el in elements] + ["Energy"])
            for entry in self.entries:
                name = re.sub(r"([0-9]+)", r"_{\1}", entry.name) if latexify_names else entry.name
                row = [name]
                row.extend(entry.composition[el] for el in elements)
                row.append(str(entry.energy))
                writer.writerow(row)

    @classmethod
    def from_csv(cls, filename: str):
        """
        Imports PDEntries from a csv. The expected layout is the one written
        by to_csv: a header row of "Name", element symbols, "Energy",
        followed by one row per entry.

        Args:
            filename: Filename to import from.

        Returns:
            EntrySet of PDEntries read from the file.
        """
        with open(filename, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(
                f,
                delimiter=unicode2str(","),
                quotechar=unicode2str('"'),
                quoting=csv.QUOTE_MINIMAL,
            )
            entries = []
            header = next(reader, None)
            if header is None:
                # Empty file: no header, hence no entries.
                return cls(entries)
            # First column is the name, last column is the energy.
            elements = header[1:-1]  # type: List[str]
            for row in reader:
                if not row:
                    # csv yields [] for blank lines (e.g. a trailing newline).
                    continue
                name = row[0]
                energy = float(row[-1])
                amounts = [float(x) for x in row[1:-1]]
                comp = {Element(elements[i]): amt for i, amt in enumerate(amounts) if amt > 0}
                entries.append(PDEntry(Composition(comp), energy, name))
        return cls(entries)