1"""
2This module provides interfaces with the Materials Project REST API v2 to enable
3the creation of data structures and pymatgen objects using Materials Project data.
4"""
5import sys
6
7from collections import OrderedDict
8from pprint import pprint
9from monty.functools import lazy_property
10from monty.string import marquee
11
12
13from pymatgen.ext.matproj import MPRester, MPRestError
14from abipy.tools.printing import print_dataframe
15from abipy.core.mixins import NotebookWriter
16
17
# URL used by get_mprester when the caller does not pass an endpoint.
MP_DEFAULT_ENDPOINT = "https://materialsproject.org/rest/v2"

# Keys extracted from each Materials Project record when building the summary
# dataframe. A dotted name ("spacegroup.symbol") addresses an entry of a
# nested dictionary.
# Keys that are listed in the original source but currently left out:
# "unit_cell_formula", "icsd_id", "icsd_ids", "cif", "tags", "elasticity".
MP_KEYS_FOR_DATAFRAME = (
    "pretty_formula",
    "e_above_hull",
    "energy_per_atom",
    "formation_energy_per_atom",
    "nsites",
    "volume",
    "spacegroup.symbol",
    "spacegroup.number",
    "band_gap",
    "total_magnetization",
    "material_id",
)
26
27
def get_mprester(api_key=None, endpoint=None):
    """
    Build the Materials Project rester used by this module.

    Args:
        api_key (str): API key for accessing the MaterialsProject REST
            interface. Please apply on the Materials Project website for one.
            If None, the value stored under `PMG_MAPI_KEY` in the pymatgen
            settings (see $HOME/.pmgrc.yaml) is used, so that heavy users
            can call this function without any argument.
        endpoint (str): Url of endpoint to access the MaterialsProject REST
            interface. None selects ``MP_DEFAULT_ENDPOINT``; other urls
            implementing a similar interface can be given.

    Return: ``MyMPRester`` instance.

    Raises:
        RuntimeError: if ``api_key`` is None and `PMG_MAPI_KEY` is not
            available in the pymatgen settings.
    """
    url = MP_DEFAULT_ENDPOINT if endpoint is None else endpoint

    # Explicit key: nothing to look up.
    if api_key is not None:
        return MyMPRester(api_key=api_key, endpoint=url)

    # No key given: fall back to the one configured for pymatgen.
    # SETTINGS is imported from pymatgen.core first, then from the top-level
    # package when the first location is not importable.
    try:
        from pymatgen.core import SETTINGS
    except ImportError:
        from pymatgen import SETTINGS

    stored_key = SETTINGS.get("PMG_MAPI_KEY")
    if stored_key is None:
        raise RuntimeError("Cannot find PMG_MAPI_KEY in pymatgen settings. Add it to $HOME/.pmgrc.yaml")

    return MyMPRester(api_key=stored_key, endpoint=url)
55
56
class MyMPRester(MPRester):
    """
    Materials Project rester extended with helper methods used by AbiPy.
    See :cite:`Jain2013,Ong2015`.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: MyMPRester
    """
    # Exception class exposed as attribute of the rester (alias of MPRestError).
    Error = MPRestError

    def get_phasediagram_results(self, elements):
        """
        Contact the materials project database, fetch the entries of the chemical
        system and build a :class:`PhaseDiagramResults` instance from them.

        Args:
            elements: List of chemical elements.

        Return: :class:`PhaseDiagramResults` object.
        """
        # The final structures are requested as PhaseDiagramResults reads e.structure.
        return PhaseDiagramResults(
            self.get_entries_in_chemsys(elements, inc_structure="final")
        )
76
77
class PhaseDiagramResults(object):
    """
    Simplified interface to phase-diagram pymatgen API.

    Inspired by:

        https://anaconda.org/matsci/plotting-and-analyzing-a-phase-diagram-using-the-materials-api/notebook

    See also: :cite:`Ong2008,Ong2010`
    """
    def __init__(self, entries):
        """
        Args:
            entries: List of entries. Each entry must provide ``structure``,
                ``entry_id`` and ``composition`` (these attributes are read
                by this class) and must be accepted by pymatgen PhaseDiagram.
        """
        from abipy.core.structure import Structure
        from pymatgen.analysis.phase_diagram import PhaseDiagram

        self.entries = entries

        # Convert pymatgen structure to Abipy.
        # NB: this changes the class of the structure objects *in place*,
        # hence the entries received in input are modified as well.
        for entry in entries:
            entry.structure.__class__ = Structure

        self.structures = [entry.structure for entry in entries]
        self.mpids = [entry.entry_id for entry in entries]

        # Create phase diagram.
        self.phasediagram = PhaseDiagram(self.entries)

    def plot(self, show_unstable=True, show=True):
        """
        Plot phase diagram.

        Args:
            show_unstable (float): Whether unstable phases will be plotted as
                well as red crosses. If a number > 0 is entered, all phases with
                ehull < show_unstable will be shown.
            show: True to show plot.

        Return: plotter object.
        """
        from pymatgen.analysis.phase_diagram import PDPlotter

        plotter = PDPlotter(self.phasediagram, show_unstable=show_unstable)
        if show:
            plotter.show()

        return plotter

    @lazy_property
    def dataframe(self):
        """
        Pandas dataframe with the most important results.

        One row per entry with columns: Materials ID, spglib_symb, spglib_num,
        Composition, Ehull, Decomposition.
        """
        import pandas as pd

        rows = []
        for entry in self.entries:
            spg = entry.structure.get_dict4pandas(with_spglib=True)
            decomp, ehull = self.phasediagram.get_decomp_and_e_above_hull(entry)

            # Human-readable decomposition e.g. "0.50 A2 B1 + 0.50 C1"
            decomp_string = " + ".join(
                "%.2f %s" % (amount, phase.composition.formula) for phase, amount in decomp.items()
            )

            rows.append(OrderedDict([
                ("Materials ID", entry.entry_id),
                ("spglib_symb", spg["spglib_symb"]),
                ("spglib_num", spg["spglib_num"]),
                ("Composition", entry.composition.reduced_formula),
                ("Ehull", ehull),
                ("Decomposition", decomp_string),
            ]))

        # Pass the column names explicitly to fix the order (None if there is no row).
        columns = list(rows[0].keys()) if rows else None
        return pd.DataFrame(rows, columns=columns)

    def print_dataframes(self, with_spglib=False, file=None, verbose=0):
        """
        Print pandas dataframe to file `file`.

        Args:
            with_spglib: True to compute spacegroup with spglib.
            file: Output stream. None means the *current* ``sys.stdout``.
            verbose: Verbosity level. If > 0, lattice parameters are printed as well.
                If > 1, atomic positions are also printed.
        """
        # Resolve stdout at call time. A `file=sys.stdout` default would be
        # evaluated once at import and ignore subsequent redirections.
        if file is None:
            file = sys.stdout

        print_dataframe(self.dataframe, file=file)
        if not verbose:
            return

        from abipy.core.structure import dataframes_from_structures
        dfs = dataframes_from_structures(self.structures, index=self.mpids, with_spglib=with_spglib)
        print_dataframe(dfs.lattice, title="Lattice parameters:", file=file)
        if verbose > 1:
            print_dataframe(dfs.coords, title="Atomic positions (columns give the site index):", file=file)
155
156
class DatabaseStructures(NotebookWriter):
    """
    Store the results of a query to a database of structures.

    Instances should be treated as immutable: use ``add_entry`` or
    ``filter_by_spgnum`` to create a new instance.

    Subclasses are expected to define the ``dbname`` attribute and the
    ``dataframe`` property that are used by ``print_results``.
    """

    def __init__(self, structures, ids, data=None):
        """
        Args:
            structures: List of structure objects
            ids: List of database ids.
            data: List of dictionaries with data associated to the structures (optional).
        """
        from abipy.core.structure import Structure
        self.structures = list(map(Structure.as_structure, structures))
        self.ids, self.data = ids, data
        assert len(self.structures) == len(ids)
        if data is not None:
            assert len(self.structures) == len(data)

    def __bool__(self):
        """bool(self): True if at least one structure is stored."""
        return bool(self.structures)
    __nonzero__ = __bool__  # py2

    def filter_by_spgnum(self, spgnum):
        """
        Filter structures by space group number.
        Return new object of the same class as self.
        """
        # Convert once instead of at each iteration.
        spgnum = int(spgnum)
        inds = [i for i, s in enumerate(self.structures) if s.get_space_group_info()[1] == spgnum]
        new_data = None if self.data is None else [self.data[i] for i in inds]
        return self.__class__([self.structures[i] for i in inds], [self.ids[i] for i in inds], data=new_data)

    def add_entry(self, structure, entry_id, data_dict=None):
        """
        Add new entry, return new object.

        Args:
           structure: New structure object.
           entry_id: ID associated to new structure.
           data_dict: Optional dictionary with metadata. If None and self has data,
                an empty dictionary is associated to the new structure.
                If given, self must already have data (AssertionError otherwise).
        """
        # list(...) so that ids/data passed as tuples (or other sequences) are supported.
        if data_dict is None:
            new_data = None if self.data is None else list(self.data) + [{}]
        else:
            assert self.data is not None
            new_data = list(self.data) + [data_dict]

        return self.__class__(self.structures + [structure], list(self.ids) + [entry_id], data=new_data)

    @property
    def lattice_dataframe(self):
        """pandas DataFrame with lattice parameters."""
        return self.structure_dataframes.lattice

    @property
    def coords_dataframe(self):
        """pandas DataFrame with atomic positions."""
        return self.structure_dataframes.coords

    @lazy_property
    def structure_dataframes(self):
        """Pandas dataframes constructed from self.structures (indexed by self.ids)."""
        from abipy.core.structure import dataframes_from_structures
        return dataframes_from_structures(self.structures, index=self.ids, with_spglib=True)

    def print_results(self, fmt="abivars", verbose=0, file=None):
        """
        Print pandas dataframe, structures using format `fmt`, and data to file `file`.
        `fmt` is automatically set to `cif` if structure is disordered.
        Set fmt to None or empty string to disable structure output.

        Args:
            fmt: Output format for the structures. None, "None" or "" to skip them.
            verbose: Verbosity level. If > 0, self.data is printed as well.
            file: Output stream. None means the *current* ``sys.stdout``.
        """
        # Resolve stdout at call time. A `file=sys.stdout` default would be
        # evaluated once at import and ignore subsequent redirections.
        if file is None:
            file = sys.stdout

        header = "\n# Found %s structures in %s database (use `verbose` to get further info)\n" % (
            len(self.structures), self.dbname)
        print(header, file=file)

        if self.dataframe is not None: print_dataframe(self.dataframe, file=file)
        if verbose and self.data is not None: pprint(self.data, stream=file)

        # Print structures. The empty string disables the output as documented
        # (it used to be passed down to structure.convert).
        print_structures = bool(fmt) and str(fmt) != "None"
        if print_structures:
            for entry_id, structure in zip(self.ids, self.structures):
                # Disordered structures are always written in cif format
                # and the banner reports the format actually used.
                out_fmt = fmt if structure.is_ordered else "cif"
                print(" ", file=file)
                print(marquee("%s input for %s" % (out_fmt, entry_id), mark="#"), file=file)
                print("# " + structure.spget_summary(verbose=verbose).replace("\n", "\n# ") + "\n", file=file)
                print(structure.convert(fmt=out_fmt), file=file)
                print(2 * "\n", file=file)

        if len(self.structures) > 10:
            # Print info again as the first header is likely out of sight.
            print(header, file=file)

    def yield_figs(self, **kwargs):  # pragma: no cover
        """NOP required by NotebookWriter protocol."""
        yield None

    def write_notebook(self, nbpath=None, title=None):
        """
        Write a jupyter notebook to nbpath. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=title)

        # Use pickle files for data persistence.
        # The notebook reloads the object from the file returned by pickle_dump.
        tmpfile = self.pickle_dump()

        nb.cells.extend([
            nbv.new_code_cell("dbs = abilab.restapi.DatabaseStructures.pickle_load('%s')" % tmpfile),
            nbv.new_code_cell("import qgrid"),
            nbv.new_code_cell("# dbs.print_results(fmt='cif', verbose=0)"),
            nbv.new_code_cell("# qgrid.show_grid(dbs.lattice_dataframe)"),
            nbv.new_code_cell("# qgrid.show_grid(dbs.coords_dataframe)"),
            nbv.new_code_cell("qgrid.show_grid(dbs.dataframe)"),
        ])

        return self._write_nb_nbpath(nb, nbpath)
276
277
class MpStructures(DatabaseStructures):
    """
    Store the results of a query to the Materials Project database.

    .. inheritance-diagram:: MpStructures
    """
    dbname = "Materials Project"

    @lazy_property
    def dataframe(self):
        """
        Pandas dataframe constructed from self.data. None if data is not available.

        One row per structure (index given by self.ids) with the values of the keys
        listed in MP_KEYS_FOR_DATAFRAME (None if the key is missing) followed
        by the lattice parameters a, b, c, alpha, beta, gamma.
        """
        if not self.data: return None
        import pandas as pd

        rows = []
        for record, structure in zip(self.data, self.structures):
            record = Dotdict(record)
            row = OrderedDict((k, record.dotget(k, default=None)) for k in MP_KEYS_FOR_DATAFRAME)
            # Add lattice parameters.
            lattice = structure.lattice
            for name in ("a", "b", "c", "alpha", "beta", "gamma"):
                row[name] = getattr(lattice, name)
            rows.append(row)

        return pd.DataFrame(rows, index=self.ids, columns=list(rows[0].keys()))

    def open_browser(self, browser=None, limit=10):
        """
        Open the Materials Project webpage of each structure in a new browser tab.

        Args:
            browser: Open webpage in ``browser``. Use the default browser if None.
            limit: Max number of tabs opened in browser. None for no limit.
        """
        # cgi.escape was removed in python 3.8 (and the cgi module in 3.13).
        # html.escape with quote=False escapes the same characters (&, <, >).
        import html
        import webbrowser

        for i, mpid in enumerate(self.ids):
            if limit is not None and i >= limit:
                print("Found %d structures. Won't open more than %d tabs" % (len(self.ids), limit))
                break
            # https://materialsproject.org/materials/mp-2172/
            url = "https://materialsproject.org/materials/%s/" % mpid
            # The browser is requested only when there is something to open
            # so that an empty list of ids never fails for lack of a browser.
            webbrowser.get(browser).open_new_tab(html.escape(url, quote=False))
317
318
class CodStructures(DatabaseStructures):
    """
    Store the results of a query to the COD_ database :cite:`Grazulis2011`.

    .. inheritance-diagram:: CodStructures
    """
    dbname = "COD"

    @lazy_property
    def dataframe(self):
        """
        |pandas-Dataframe| constructed. Essentially geometrical info and space groups found by spglib_
        as COD API is rather limited.

        The ``cod_sg`` column stores the space group reported by COD (key "sg" of
        the data dictionaries) with whitespace removed. An empty string is used
        when the value is not available.
        """
        df = self.lattice_dataframe.copy()

        # `data` is optional in the base class: use empty dictionaries if None
        # so that the dataframe (needed by print_results) can still be built.
        data = self.data if self.data is not None else [{}] * len(self.structures)

        # Add space group from COD. `or ""` also covers "sg" explicitly set to None.
        df["cod_sg"] = [(d.get("sg") or "").replace(" ", "") for d in data]
        return df
337
338
class Dotdict(dict):
    """
    Dictionary whose nested entries can be accessed with a dotted key.
    See the ``dotget`` method.
    """

    def dotget(self, key, default=None):
        """
        Return the value associated to ``key`` where a dotted key addresses nested mappings:

            d.dotget("foo.bar") --> d["foo"]["bar"] if "foo.bar" not in self

        Any number of levels is supported ("foo.bar.baz"). At each level, a key
        that literally contains dots takes precedence over the nested lookup.

        Args:
            key: Key to look up. Keys that are not strings are supported but
                only for direct (non-nested) access.
            default: Value returned if ``key`` cannot be resolved.

        Return: The value found or ``default``.
        """
        from collections.abc import Mapping

        # if key is in dict access as normal
        if key in self:
            return super().__getitem__(key)

        # Dotted access makes sense only for strings.
        if not isinstance(key, str):
            return default

        # Walk down the nested mappings, consuming one component at a time.
        node, rest = self, key
        while "." in rest:
            root, _, rest = rest.partition(".")
            if root not in node:
                return default
            node = node[root]
            if not isinstance(node, Mapping):
                # Cannot go deeper e.g. "band_gap.foo" with a float band_gap.
                return default
            if rest in node:
                return node[rest]

        return default
363