1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5
6"""
7This module implements functions to perform various useful operations on
8entries, such as grouping entries by structure.
9"""
10
11
12import collections
13import csv
14import datetime
15import itertools
16import json
17import logging
18import re
19from typing import Iterable, List, Set, Union
20
21from monty.json import MontyDecoder, MontyEncoder, MSONable
22from monty.string import unicode2str
23
24from pymatgen.analysis.phase_diagram import PDEntry
25from pymatgen.analysis.structure_matcher import SpeciesComparator, StructureMatcher
26from pymatgen.core.composition import Composition
27from pymatgen.core.periodic_table import Element
28from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
29
# Module-level logger named after this module; used for progress messages
# emitted while grouping entries.
logger = logging.getLogger(__name__)
31
32
33def _get_host(structure, species_to_remove):
34    if species_to_remove:
35        s = structure.copy()
36        s.remove_species(species_to_remove)
37        return s
38    return structure
39
40
def _perform_grouping(args):
    """
    Group entries whose host structures match each other and append every
    group, serialized as a JSON string, to a caller-supplied list.

    All inputs are packed into a single tuple so that the function can be
    handed directly to ``multiprocessing.Pool.map``.

    Args:
        args: Tuple of (entries_json, hosts_json, ltol, stol, angle_tol,
            primitive_cell, scale, comparator, groups).
            ``entries_json`` and ``hosts_json`` are JSON strings which are
            decoded with MontyDecoder into two parallel lists (the i-th host
            belongs to the i-th entry). ``ltol``, ``stol``, ``angle_tol``,
            ``primitive_cell``, ``scale`` and ``comparator`` are forwarded
            unchanged to StructureMatcher. ``groups`` is a list-like object
            that receives one MontyEncoder JSON string per group of entries.

    Returns:
        None. Results are communicated through ``groups``.
    """
    (
        entries_json,
        hosts_json,
        ltol,
        stol,
        angle_tol,
        primitive_cell,
        scale,
        comparator,
        groups,
    ) = args

    entries = json.loads(entries_json, cls=MontyDecoder)
    hosts = json.loads(hosts_json, cls=MontyDecoder)

    # The matcher settings are the same for every comparison, so build the
    # matcher once instead of once per pair.
    matcher = StructureMatcher(
        ltol=ltol,
        stol=stol,
        angle_tol=angle_tol,
        primitive_cell=primitive_cell,
        scale=scale,
        comparator=comparator,
    )

    unmatched = list(zip(entries, hosts))
    while unmatched:
        # The first remaining pair acts as the reference of a new group.
        ref_entry, ref_host = unmatched[0]
        logger.info("Reference tid = %s, formula = %s", ref_entry.entry_id, ref_host.formula)
        logger.info("Reference host = %s", ref_host.composition.reduced_formula)

        matched_entries = [ref_entry]
        remaining = []
        for test_entry, test_host in unmatched[1:]:
            logger.info("Testing tid = %s, formula = %s", test_entry.entry_id, test_host.formula)
            logger.info("Test host = %s", test_host.composition.reduced_formula)
            if matcher.fit(ref_host, test_host):
                logger.info("Fit found")
                matched_entries.append(test_entry)
            else:
                # Partition on the fit result itself rather than on object
                # equality afterwards, so no pair can be dropped without
                # ending up in a group.
                remaining.append((test_entry, test_host))

        groups.append(json.dumps(matched_entries, cls=MontyEncoder))
        unmatched = remaining
        logger.info("%s unmatched remaining", len(unmatched))
82
83
def group_entries_by_structure(
    entries,
    species_to_remove=None,
    ltol=0.2,
    stol=0.4,
    angle_tol=5,
    primitive_cell=True,
    scale=True,
    comparator=SpeciesComparator(),
    ncpus=None,
):
    """
    Given a sequence of ComputedStructureEntries, use structure fitter to group
    them by structural similarity.

    Args:
        entries: Sequence of ComputedStructureEntries. Any iterable is
            accepted; it is traversed exactly once.
        species_to_remove: Sometimes you want to compare a host framework
            (e.g., in Li-ion battery analysis). This allows you to specify
            species to remove before structural comparison.
        ltol (float): Fractional length tolerance. Default is 0.2.
        stol (float): Site tolerance in Angstrom. Default is 0.4 Angstrom.
        angle_tol (float): Angle tolerance in degrees. Default is 5 degrees.
        primitive_cell (bool): If true: input structures will be reduced to
            primitive cells prior to matching. Defaults to True.
        scale: Input structures are scaled to equivalent volume if true;
            For exact matching, set to False.
        comparator: A comparator object implementing an equals method that
            declares equivalency of sites. Default is SpeciesComparator,
            which implies rigid species mapping.
        ncpus: Number of cpus to use. Use of multiple cpus can greatly improve
            fitting speed. Default of None means serial processing.

    Returns:
        Sequence of sequence of entries by structural similarity. e.g,
        [[ entry1, entry2], [entry3, entry4, entry5]]
    """
    start = datetime.datetime.now()
    logger.info("Started at %s", start)
    entries_host = [(entry, _get_host(entry.structure, species_to_remove)) for entry in entries]
    if ncpus:
        # Pre-bucket by the comparator's structure hash; each bucket is then
        # grouped independently in a worker process.
        symm_entries = collections.defaultdict(list)
        for entry, host in entries_host:
            symm_entries[comparator.get_structure_hash(host)].append((entry, host))
        import multiprocessing as mp

        logger.info("Using %s cpus", ncpus)
        # Parallel processing only supports Python primitives and not objects,
        # hence entries and hosts are passed around as JSON strings.
        tasks = [
            (
                json.dumps([pair[0] for pair in bucket], cls=MontyEncoder),
                json.dumps([pair[1] for pair in bucket], cls=MontyEncoder),
                ltol,
                stol,
                angle_tol,
                primitive_cell,
                scale,
                comparator,
            )
            for bucket in symm_entries.values()
        ]
        # The manager is used as a context manager so that its server process
        # is shut down once the results have been collected.
        with mp.Manager() as manager:
            shared_groups = manager.list()
            with mp.Pool(ncpus) as pool:
                pool.map(_perform_grouping, [task + (shared_groups,) for task in tasks])
            # Copy the JSON strings out before the manager goes away.
            groups = list(shared_groups)
    else:
        groups = []
        # Serialize from the materialized pairs rather than from ``entries``
        # again, which may be a one-shot iterable that is already exhausted.
        _perform_grouping(
            (
                json.dumps([entry for entry, _ in entries_host], cls=MontyEncoder),
                json.dumps([host for _, host in entries_host], cls=MontyEncoder),
                ltol,
                stol,
                angle_tol,
                primitive_cell,
                scale,
                comparator,
                groups,
            )
        )
    entry_groups = [json.loads(g, cls=MontyDecoder) for g in groups]
    logger.info("Finished at %s", datetime.datetime.now())
    logger.info("Took %s", datetime.datetime.now() - start)
    return entry_groups
174
175
class EntrySet(collections.abc.MutableSet, MSONable):
    """
    A convenient container for manipulating entries. Allows for generating
    subsets, dumping into files, etc.
    """

    def __init__(self, entries: Iterable[Union[PDEntry, ComputedEntry, ComputedStructureEntry]]):
        """
        Args:
            entries: All the entries. Duplicates are collapsed because the
                entries are stored in a set.
        """
        self.entries = set(entries)

    def __contains__(self, item):
        return item in self.entries

    def __iter__(self):
        return iter(self.entries)

    def __len__(self):
        return len(self.entries)

    def add(self, element):
        """
        Add an entry.

        :param element: Entry
        """
        self.entries.add(element)

    def discard(self, element):
        """
        Discard an entry. Nothing happens if the entry is not present.

        :param element: Entry
        """
        self.entries.discard(element)

    @property
    def chemsys(self) -> set:
        """
        Returns:
            set representing the chemical system, e.g., {"Li", "Fe", "P", "O"}
        """
        return {el.symbol for entry in self.entries for el in entry.composition.keys()}

    def remove_non_ground_states(self):
        """
        Removes all non-ground state entries, i.e., only keep the lowest energy
        per atom entry at each composition (compositions are identified by
        their reduced formula).
        """

        def reduced_formula(entry):
            return entry.composition.reduced_formula

        # itertools.groupby only merges adjacent items, so sort by the same
        # key first.
        ordered = sorted(self.entries, key=reduced_formula)
        self.entries = {
            min(group, key=lambda e: e.energy_per_atom) for _, group in itertools.groupby(ordered, key=reduced_formula)
        }

    def get_subset_in_chemsys(self, chemsys: List[str]):
        """
        Returns an EntrySet containing only the set of entries belonging to
        a particular chemical system (in this definition, it includes all sub
        systems). For example, if the entries are from the
        Li-Fe-P-O system, and chemsys=["Li", "O"], only the Li, O,
        and Li-O entries are returned.

        Args:
            chemsys: Chemical system specified as list of elements. E.g.,
                ["Li", "O"]

        Returns:
            EntrySet

        Raises:
            ValueError: If chemsys contains an element that does not occur
                in any entry of this set.
        """
        chem_sys = set(chemsys)
        if not chem_sys.issubset(self.chemsys):
            raise ValueError("%s is not a subset of %s" % (chem_sys, self.chemsys))
        return EntrySet(
            entry for entry in self.entries if chem_sys.issuperset(sp.symbol for sp in entry.composition.keys())
        )

    def as_dict(self):
        """
        :return: MSONable dict
        """
        return {"entries": list(self.entries)}

    def to_csv(self, filename: str, latexify_names: bool = False):
        """
        Exports the entries of this set to a csv file. The header row is
        "Name", one column per element (sorted by ``Element.X``) and
        "Energy"; every following row describes one entry.

        Args:
            filename: Filename to write to.
            latexify_names: Format entry names to be LaTex compatible,
                e.g., Li_{2}O
        """
        els = set()  # type: Set[Element]
        for entry in self.entries:
            els.update(entry.composition.elements)
        elements = sorted(els, key=lambda a: a.X)
        # newline="" is what the csv module requires to avoid spurious blank
        # rows on some platforms; utf-8 matches what from_csv reads.
        with open(filename, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["Name"] + [el.symbol for el in elements] + ["Energy"])
            for entry in self.entries:
                # Wrap every run of digits as a LaTeX subscript if requested.
                name = re.sub(r"([0-9]+)", r"_{\1}", entry.name) if latexify_names else entry.name
                row = [name]
                row.extend(entry.composition[el] for el in elements)
                row.append(str(entry.energy))
                writer.writerow(row)

    @classmethod
    def from_csv(cls, filename: str):
        """
        Imports PDEntries from a csv laid out as written by ``to_csv``: a
        header row ("Name", element symbols..., "Energy") followed by one row
        per entry. Blank rows are ignored.

        Args:
            filename: Filename to import from.

        Returns:
            EntrySet made of PDEntry objects.
        """
        with open(filename, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
            entries = []
            header_read = False
            elements = []  # type: List[str]
            for row in reader:
                if not row:
                    # Tolerate blank lines, e.g. in files written without
                    # newline="" on platforms that translate line endings.
                    continue
                if not header_read:
                    # Everything between the name and energy columns is an
                    # element symbol.
                    elements = row[1:-1]
                    header_read = True
                    continue
                name = row[0]
                energy = float(row[-1])
                comp = {}
                for ind, raw_amount in enumerate(row[1:-1]):
                    amount = float(raw_amount)
                    # Elements with a zero amount are not part of the composition.
                    if amount > 0:
                        comp[Element(elements[ind])] = amount
                entries.append(PDEntry(Composition(comp), energy, name))
        return cls(entries)
330