1# coding: utf-8
2"""History file with structural relaxation results."""
3import os
4import numpy as np
5import pymatgen.core.units as units
6
7from collections import OrderedDict
8from monty.functools import lazy_property
9from monty.collections import AttrDict
10from monty.string import marquee, list_strings
11from pymatgen.core.periodic_table import Element
12from pymatgen.analysis.structure_analyzer import RelaxationAnalyzer
13from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_visible
14from abipy.core.structure import Structure
15from abipy.core.mixins import AbinitNcFile, NotebookWriter
16from abipy.abio.robots import Robot
17from abipy.iotools import ETSF_Reader
18import abipy.core.abinit_units as abu
19
20
class HistFile(AbinitNcFile, NotebookWriter):
    """
    File with the history of a structural relaxation or molecular dynamics calculation.

    Usage example:

    .. code-block:: python

        with HistFile("foo_HIST") as hist:
            hist.plot()


    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: HistFile
    """
    @classmethod
    def from_file(cls, filepath):
        """Initialize the object from a netcdf_ file."""
        return cls(filepath)

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = HistReader(filepath)

    def close(self):
        """Close the file."""
        self.reader.close()

    @lazy_property
    def params(self):
        """Dictionary with parameters that might be subject to convergence studies (currently empty)."""
        return {}

    def __str__(self):
        return self.to_string()

    # TODO: Add more metadata (e.g. nsppol, nspden, nspinor).
    # These dimensions would have to be stored in the HIST file first.

    @lazy_property
    def final_energy(self):
        """Total energy in eV of the last iteration."""
        return self.etotals[-1]

    @lazy_property
    def final_pressure(self):
        """Final pressure in GPa."""
        _, pressures = self.reader.read_cart_stress_tensors()
        return pressures[-1]

    def get_fstats_dict(self, step):
        """
        Return |AttrDict| with stats on the forces at the given ``step``.

        Args:
            step: Index of the iteration step (e.g. 0 for the first step, -1 for the last one).

        Returns:
            |AttrDict| with the min, max, mean and standard deviation of the force moduli
            (``fmin``, ``fmax``, ``fmean``, ``fstd``) and the norm of the total force
            (``drift``), in eV/Ang.
        """
        # fcart has shape [time, natom, 3]. Only the requested step is read from file.
        var = self.reader.read_variable("fcart")
        forces = units.ArrayWithUnit(var[step], "Ha bohr^-1").to("eV ang^-1")
        fmods = np.array([np.linalg.norm(force) for force in forces])

        return AttrDict(
            fmin=fmods.min(),
            fmax=fmods.max(),
            fmean=fmods.mean(),
            fstd=fmods.std(),
            drift=np.linalg.norm(forces.sum(axis=0)),
        )

    def to_string(self, verbose=0, title=None):
        """
        String representation.

        Args:
            verbose: Verbosity level passed to the structure ``to_string`` methods.
            title: Optional title added at the top of the string.
        """
        lines = []; app = lines.append
        if title is not None: app(marquee(title, mark="="))

        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.initial_structure.to_string(verbose=verbose, title="Initial Structure"))
        app("")
        app("Number of relaxation steps performed: %d" % self.num_steps)
        app(self.final_structure.to_string(verbose=verbose, title="Final structure"))
        app("")

        an = self.get_relaxation_analyzer()
        app("Volume change in percentage: %.2f%%" % (an.get_percentage_volume_change() * 100))
        d = an.get_percentage_lattice_parameter_changes()
        vals = tuple(d[k] * 100 for k in ("a", "b", "c"))
        # All three lattice parameters use the same two-decimal format.
        app("Percentage lattice parameter changes:\n\ta: %.2f%%, b: %.2f%%, c: %.2f%%" % vals)
        app("")

        cart_stress_tensors, pressures = self.reader.read_cart_stress_tensors()
        app("Stress tensor (Cartesian coordinates in GPa):\n%s" % cart_stress_tensors[-1])
        app("Pressure: %.3f [GPa]" % pressures[-1])

        return "\n".join(lines)

    @property
    def num_steps(self):
        """Number of iterations performed."""
        return self.reader.num_steps

    @lazy_property
    def steps(self):
        """Step indices."""
        return list(range(self.num_steps))

    @property
    def initial_structure(self):
        """The initial |Structure|."""
        return self.structures[0]

    @property
    def final_structure(self):
        """The |Structure| of the last iteration."""
        return self.structures[-1]

    @lazy_property
    def structures(self):
        """List of |Structure| objects at the different steps."""
        return self.reader.read_all_structures()

    @lazy_property
    def etotals(self):
        """|numpy-array| with total energies in eV at the different steps."""
        return self.reader.read_eterms().etotals

    def get_relaxation_analyzer(self):
        """
        Return a pymatgen :class:`RelaxationAnalyzer` object to analyze the relaxation in a calculation.
        """
        return RelaxationAnalyzer(self.initial_structure, self.final_structure)

    def to_xdatcar(self, filepath=None, groupby_type=True, to_unit_cell=False, **kwargs):
        """
        Return Xdatcar pymatgen object. See write_xdatcar for the meaning of arguments.

        Args:
            to_unit_cell (bool): Whether to translate sites into the unit cell.
            kwargs: keywords arguments passed to Xdatcar constructor.
        """
        filepath = self.write_xdatcar(filepath=filepath, groupby_type=groupby_type,
                                      to_unit_cell=to_unit_cell, overwrite=True)
        from pymatgen.io.vasp.outputs import Xdatcar
        return Xdatcar(filepath, **kwargs)

    def write_xdatcar(self, filepath="XDATCAR", groupby_type=True, overwrite=False, to_unit_cell=False):
        """
        Write Xdatcar file with unit cell and atomic positions to file ``filepath``.

        Args:
            filepath: Xdatcar filename. If None, a temporary file is created.
            groupby_type: If True, atoms are grouped by type. Note that this option
                may change the order of the atoms. This option is needed because
                there are post-processing tools (e.g. ovito) that do not work as expected
                if the atoms in the structure are not grouped by type.
            overwrite: raise RuntimeError, if False and filepath exists.
            to_unit_cell (bool): Whether to translate sites into the unit cell.

        Return:
            path to Xdatcar file.

        Raises:
            RuntimeError: if ``filepath`` exists and ``overwrite`` is False.
            NotImplementedError: if the file uses alchemical mixing (npsp != ntypat).
        """
        if filepath is not None and os.path.exists(filepath) and not overwrite:
            raise RuntimeError("Cannot overwrite pre-existing file `%s`" % filepath)
        if filepath is None:
            import tempfile
            fd, filepath = tempfile.mkstemp(text=True, suffix="_XDATCAR")
            # mkstemp returns an open OS-level descriptor. Close it to avoid leaking it;
            # the file is reopened by name below.
            os.close(fd)

        # int typat[natom], double znucl[npsp]
        # NB: typat is double in the HIST.nc file
        typat = self.reader.read_value("typat").astype(int)
        znucl = self.reader.read_value("znucl")
        ntypat = self.reader.read_dimvalue("ntypat")
        num_pseudos = self.reader.read_dimvalue("npsp")
        if num_pseudos != ntypat:
            raise NotImplementedError("Alchemical mixing is not supported, num_pseudos != ntypat")

        # Map each chemical symbol to the indices of the atoms of that type,
        # in order of first appearance.
        symb2pos = OrderedDict()
        symbols_atom = []
        for iatom, itype in enumerate(typat):
            # typat values start at 1, hence the shift to index znucl.
            symbol = Element.from_Z(int(znucl[itype - 1])).symbol
            symb2pos.setdefault(symbol, []).append(iatom)
            symbols_atom.append(symbol)

        if groupby_type:
            group_ids = np.array([iatom for pos_list in symb2pos.values() for iatom in pos_list], dtype=int)
        else:
            group_ids = np.arange(self.reader.natom)

        comment = " %s\n" % self.initial_structure.formula
        with open(filepath, "wt") as fh:
            # comment line  + scaling factor set to 1.0
            fh.write(comment)
            fh.write("1.0\n")
            # NOTE(review): only the initial lattice is written, so cell changes
            # along the trajectory are not represented in the file.
            for vec in self.initial_structure.lattice.matrix:
                fh.write("%.12f %.12f %.12f\n" % (vec[0], vec[1], vec[2]))
            if not groupby_type:
                fh.write(" ".join(symbols_atom) + "\n")
                fh.write("1 " * len(symbols_atom) + "\n")
            else:
                fh.write(" ".join(symb2pos.keys()) + "\n")
                fh.write(" ".join(str(len(p)) for p in symb2pos.values()) + "\n")

            # Write atomic positions in reduced coordinates.
            xred_list = self.reader.read_value("xred")
            if to_unit_cell:
                xred_list = xred_list % 1

            for step in range(self.num_steps):
                fh.write("Direct configuration= %d\n" % (step + 1))
                frac_coords = xred_list[step, group_ids]
                for fs in frac_coords:
                    fh.write("%.12f %.12f %.12f\n" % (fs[0], fs[1], fs[2]))

        return filepath

    def visualize(self, appname="ovito", to_unit_cell=False):  # pragma: no cover
        """
        Visualize the crystalline structure with visualizer.
        See :class:`Visualizer` for the list of applications and formats supported.

        Args:
            appname: Name of the visualizer. "mayavi" uses :meth:`mvplot_trajectories`,
                otherwise only "ovito" is supported.
            to_unit_cell (bool): Whether to translate sites into the unit cell.
        """
        # There is no ``mayaview`` method: the mayavi visualization of the
        # trajectories is provided by mvplot_trajectories.
        if appname == "mayavi": return self.mvplot_trajectories()

        # Get the Visualizer subclass from the string.
        from abipy.iotools import Visualizer
        visu = Visualizer.from_name(appname)
        if visu.name != "ovito":
            raise NotImplementedError("visualizer: %s" % visu.name)

        filepath = self.write_xdatcar(filepath=None, groupby_type=True, to_unit_cell=to_unit_cell)

        return visu(filepath)()

    @staticmethod
    def _prefixed_label(prefix, name):
        """Return ``name``, prefixed by ``prefix`` when the latter is not None."""
        return name if prefix is None else "%s %s" % (prefix, name)

    def plot_ax(self, ax, what, fontsize=8, **kwargs):
        """
        Helper function to plot quantity ``what`` on axis ``ax``.

        Args:
            ax: |matplotlib-Axes|.
            what: Quantity to plot. One of "energy", "abc", "a", "b", "c", "angles",
                "alpha", "beta", "gamma", "volume", "pressure", "forces".
            fontsize: fontsize for legend.
            kwargs are passed to matplotlib plot method.
                For quantities with several curves ("abc", "angles", "forces"), a ``label``
                kwarg is used as a prefix of the per-curve labels instead of being passed
                twice to matplotlib.

        Raises:
            ValueError: if ``what`` is not supported.
        """
        label = None
        if what == "energy":
            # Total energy in eV.
            marker = kwargs.pop("marker", "o")
            label = kwargs.pop("label", "Energy")
            ax.plot(self.steps, self.etotals, label=label, marker=marker, **kwargs)
            ax.set_ylabel('Energy (eV)')

        elif what == "abc":
            # Lattice parameters.
            mark = kwargs.pop("marker", None)
            prefix = kwargs.pop("label", None)
            markers = ["o", "^", "v"] if mark is None else 3 * [mark]
            for i, comp in enumerate(["a", "b", "c"]):
                label = self._prefixed_label(prefix, comp)
                ax.plot(self.steps, [s.lattice.abc[i] for s in self.structures], label=label,
                        marker=markers[i], **kwargs)
            ax.set_ylabel("abc (A)")

        elif what in ("a", "b", "c"):
            i = ("a", "b", "c").index(what)
            marker = kwargs.pop("marker", None)
            if marker is None:
                marker = {"a": "o", "b": "^", "c": "v"}[what]
            label = kwargs.pop("label", what)
            ax.plot(self.steps, [s.lattice.abc[i] for s in self.structures], label=label,
                    marker=marker, **kwargs)
            ax.set_ylabel('%s (A)' % what)

        elif what == "angles":
            # Lattice Angles
            mark = kwargs.pop("marker", None)
            prefix = kwargs.pop("label", None)
            markers = ["o", "^", "v"] if mark is None else 3 * [mark]
            for i, comp in enumerate(["alpha", "beta", "gamma"]):
                label = self._prefixed_label(prefix, comp)
                ax.plot(self.steps, [s.lattice.angles[i] for s in self.structures], label=label,
                        marker=markers[i], **kwargs)
            ax.set_ylabel(r"$\alpha\beta\gamma$ (degree)")

        elif what in ("alpha", "beta", "gamma"):
            i = ("alpha", "beta", "gamma").index(what)
            marker = kwargs.pop("marker", None)
            if marker is None:
                marker = {"alpha": "o", "beta": "^", "gamma": "v"}[what]

            label = kwargs.pop("label", what)
            ax.plot(self.steps, [s.lattice.angles[i] for s in self.structures], label=label,
                    marker=marker, **kwargs)
            ax.set_ylabel(r"$\%s$ (degree)" % what)

        elif what == "volume":
            marker = kwargs.pop("marker", "o")
            ax.plot(self.steps, [s.lattice.volume for s in self.structures], marker=marker, **kwargs)
            ax.set_ylabel(r'$V\, (A^3)$')

        elif what == "pressure":
            _, pressures = self.reader.read_cart_stress_tensors()
            marker = kwargs.pop("marker", "o")
            label = kwargs.pop("label", "P")
            ax.plot(self.steps, pressures, label=label, marker=marker, **kwargs)
            ax.set_ylabel('P (GPa)')

        elif what == "forces":
            forces_hist = self.reader.read_cart_forces()
            fmin_steps, fmax_steps, fmean_steps, fstd_steps = [], [], [], []
            for step in range(self.num_steps):
                forces = forces_hist[step]
                fmods = np.sqrt([np.dot(force, force) for force in forces])
                fmean_steps.append(fmods.mean())
                fstd_steps.append(fmods.std())
                fmin_steps.append(fmods.min())
                fmax_steps.append(fmods.max())

            mark = kwargs.pop("marker", None)
            prefix = kwargs.pop("label", None)
            markers = ["o", "^", "v", "X"] if mark is None else 4 * [mark]
            curves = (
                ("min |F|", fmin_steps),
                ("max |F|", fmax_steps),
                ("mean |F|", fmean_steps),
                ("std |F|", fstd_steps),
            )
            for marker, (name, values) in zip(markers, curves):
                label = self._prefixed_label(prefix, name)
                ax.plot(self.steps, values, label=label, marker=marker, **kwargs)
            ax.set_ylabel('F stats (eV/A)')

        else:
            raise ValueError("Invalid value for what: `%s`" % str(what))

        ax.set_xlabel('Step')
        ax.grid(True)
        if label is not None:
            ax.legend(loc='best', fontsize=fontsize, shadow=True)

    @add_fig_kwargs
    def plot(self, what_list=None, ax_list=None, fontsize=8, **kwargs):
        """
        Plot the evolution of structural parameters (lattice lengths, angles and volume)
        as well as pressure, info on forces and total energy.

        Args:
            what_list: List of quantities to plot (see :meth:`plot_ax`). None for the default list.
            ax_list: List of |matplotlib-Axes|. If None, a new figure is created.
            fontsize: fontsize for legend

        Returns: |matplotlib-Figure|

        Raises:
            ValueError: if fewer axes than quantities are available.
        """
        if what_list is None:
            what_list = ["abc", "angles", "volume", "pressure", "forces", "energy"]
        else:
            what_list = list_strings(what_list)

        nplots = len(what_list)
        nrows, ncols = 1, 1
        if nplots > 1:
            ncols = 2
            nrows = nplots // ncols + nplots % ncols

        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = np.ravel(ax_list)
        # The grid may contain one extra ax when nplots is odd, so require "at least" nplots.
        if len(ax_list) < nplots:
            raise ValueError("Need at least %d axes but got %d" % (nplots, len(ax_list)))

        # Don't show the axes that are not used (e.g. the last one if nplots is odd).
        for ax in ax_list[nplots:]:
            ax.axis("off")

        for what, ax in zip(what_list, ax_list):
            self.plot_ax(ax, what, fontsize=fontsize, marker="o")

        return fig

    @add_fig_kwargs
    def plot_energies(self, ax=None, fontsize=8, **kwargs):
        """
        Plot the total energies as function of the iteration step.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: Legend and title fontsize.

        Returns: |matplotlib-Figure|
        """
        # TODO max force and pressure
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        terms = self.reader.read_eterms()
        for key, values in terms.items():
            # Skip terms that are identically zero (nothing to show).
            if np.all(values == 0.0): continue
            ax.plot(self.steps, values, marker="o", label=key)

        ax.set_xlabel('Step')
        ax.set_ylabel('Energies (eV)')
        ax.grid(True)
        ax.legend(loc='best', fontsize=fontsize, shadow=True)

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        yield self.plot(show=False)
        yield self.plot_energies(show=False)

    def mvplot_trajectories(self, colormap="hot", sampling=1, figure=None, show=True,
                            with_forces=True, **kwargs):  # pragma: no cover
        """
        Call mayavi_ to plot atomic trajectories and the variation of the unit cell.

        Args:
            colormap: Colormap used for trajectories and forces.
            sampling: Plot one step every ``sampling`` steps.
            figure: mayavi figure. None to create a new one.
            show: True to call ``mlab.show``.
            with_forces: True to add arrows with the Cartesian forces.

        Returns: The mayavi figure.
        """
        from abipy.display import mvtk
        figure, mlab = mvtk.get_fig_mlab(figure=figure)
        style = "labels"
        line_width = 100
        # Initial cell in red, final cell in black.
        mvtk.plot_structure(self.initial_structure, style=style, unit_cell_color=(1, 0, 0), figure=figure)
        mvtk.plot_structure(self.final_structure, style=style, unit_cell_color=(0, 0, 0), figure=figure)

        steps = np.arange(start=0, stop=self.num_steps, step=sampling)
        xcart_list = self.reader.read_value("xcart") * units.bohr_to_ang
        for iatom in range(self.reader.natom):
            x, y, z = xcart_list[::sampling, iatom, :].T
            trajectory = mlab.plot3d(x, y, z, steps, colormap=colormap, tube_radius=None,
                                     line_width=line_width, figure=figure)
            mlab.colorbar(trajectory, title='Iteration', orientation='vertical')

        if with_forces:
            fcart_list = self.reader.read_cart_forces(unit="eV ang^-1")
            for iatom in range(self.reader.natom):
                x, y, z = xcart_list[::sampling, iatom, :].T
                u, v, w = fcart_list[::sampling, iatom, :].T
                mlab.quiver3d(x, y, z, u, v, w, figure=figure, colormap=colormap,
                              line_width=line_width, scale_factor=10)

        if show: mlab.show()
        return figure

    def mvanimate(self, delay=500):  # pragma: no cover
        """
        Use mayavi_ to animate the structures at the different iteration steps.

        Args:
            delay: Delay between frames, passed to ``mlab.animate``.
        """
        from abipy.display import mvtk
        figure, mlab = mvtk.get_fig_mlab(figure=None)
        style = "points"

        @mlab.show
        @mlab.animate(delay=delay, ui=True)
        def anim():
            """Animate."""
            for it, structure in enumerate(self.structures):
                print('Updating scene for iteration:', it)
                mvtk.plot_structure(structure, style=style, figure=figure)
                mlab.draw(figure=figure)
                yield

        anim()

    def get_panel(self):
        """Build panel with widgets to interact with the |HistFile| either in a notebook or in panel app."""
        from abipy.panels.hist import HistFilePanel
        return HistFilePanel(self).get_panel()

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("hist = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(hist)"),
            nbv.new_code_cell("hist.plot_energies();"),
            nbv.new_code_cell("hist.plot();"),
        ])

        return self._write_nb_nbpath(nb, nbpath)
539
540
class HistRobot(Robot):
    """
    This robot analyzes the results contained in multiple HIST.nc_ files.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: HistRobot
    """
    EXT = "HIST"

    def to_string(self, verbose=0):
        """String representation with verbosity level ``verbose``."""
        s = super().to_string(verbose=0) if verbose else ""
        df = self.get_dataframe()
        s_df = "Table with final structures, pressures in GPa and force stats in eV/Ang:\n\n%s" % str(df)
        return "\n".join([s, s_df]) if s else s_df

    def get_dataframe(self, with_geo=True, index=None, abspath=False, with_spglib=True, funcs=None, **kwargs):
        """
        Return a |pandas-DataFrame| with the most important final results and the filenames as index.

        Args:
            with_geo: True if structure info should be added to the dataframe
            abspath: True if paths in index should be absolute. Default: Relative to getcwd().
            index: Index of the dataframe, if None, robot labels are used
            with_spglib: If True, spglib_ is invoked to get the space group symbol and number
            funcs: Function or list of functions to execute to add more data to the DataFrame.
                Each function receives a |HistFile| object and returns a tuple (key, value)
                where key is a string with the name of column and value is the value to be inserted.

        kwargs:
            attrs:
                List of additional attributes of the |HistFile| to add to the |pandas-DataFrame|.
        """
        # Keys of the dictionary returned by HistFile.get_fstats_dict.
        fstat_keys = ("fmin", "fmax", "fmean", "fstd", "drift")

        # Add attributes specified by the users
        attrs = [
            "num_steps", "final_energy", "final_pressure",
            "final_fmin", "final_fmax", "final_fmean", "final_fstd", "final_drift",
            "initial_fmin", "initial_fmax", "initial_fmean", "initial_fstd", "initial_drift",
            # TODO add more columns but must update HIST file
            #"nsppol", "nspinor", "nspden",
            #"ecut", "pawecutdg", "tsmear", "nkpt",
        ] + kwargs.pop("attrs", [])

        rows, row_names = [], []
        for label, hist in self.items():
            row_names.append(label)
            d = OrderedDict()

            # Force statistics at the first and at the last step.
            fstats = {
                "initial": hist.get_fstats_dict(step=0),
                "final": hist.get_fstats_dict(step=-1),
            }

            # Add info on structure.
            if with_geo:
                d.update(hist.final_structure.get_dict4pandas(with_spglib=with_spglib))

            for aname in attrs:
                # "final_fmax" -> ("final", "fmax"): force stats are taken from fstats,
                # everything else is an attribute of the HistFile.
                when, _, key = aname.partition("_")
                if when in fstats and key in fstat_keys:
                    value = fstats[when][key]
                else:
                    value = getattr(hist, aname, None)
                d[aname] = value

            # Execute functions
            if funcs is not None: d.update(self._exec_funcs(funcs, hist))
            rows.append(d)

        import pandas as pd
        # NOTE(review): relative paths are produced when abspath is True, which looks inverted
        # with respect to the docstring. This is left unchanged to preserve existing behavior.
        row_names = row_names if not abspath else self._to_relpaths(row_names)
        index = row_names if index is None else index
        # Guard against an empty robot (rows[0] would raise IndexError).
        columns = list(rows[0].keys()) if rows else None
        return pd.DataFrame(rows, index=index, columns=columns)

    @property
    def what_list(self):
        """List with all quantities that can be plotted (what_list)."""
        return ["energy", "abc", "angles", "volume", "pressure", "forces"]

    @add_fig_kwargs
    def gridplot(self, what_list=None, sharex="row", sharey="row", fontsize=8, **kwargs):
        """
        Plot the ``what`` value extracted from multiple HIST.nc_ files on a grid.

        Args:
            what_list: List of quantities to plot.
                Must be in ["energy", "abc", "angles", "volume", "pressure", "forces"]
            sharex: How the x-axis should be shared (bool or matplotlib string, e.g. "row").
            sharey: How the y-axis should be shared (bool or matplotlib string, e.g. "row").
            fontsize: fontsize for legend.

        Returns: |matplotlib-Figure|
        """
        what_list = list_strings(what_list) if what_list is not None else self.what_list

        # Build grid of plots: one row per quantity, one column per file.
        nrows, ncols = len(what_list), len(self)

        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=sharex, sharey=sharey, squeeze=False)
        ax_mat = np.reshape(ax_mat, (nrows, ncols))

        for irow, what in enumerate(what_list):
            for icol, hist in enumerate(self.abifiles):
                ax = ax_mat[irow, icol]
                ax.grid(True)
                hist.plot_ax(ax, what, fontsize=fontsize, marker="o")

                # Titles only on the first row, xlabels only on the last row,
                # ylabels only on the first column.
                if irow == 0:
                    ax.set_title(hist.relpath, fontsize=fontsize)
                if irow != nrows - 1:
                    set_visible(ax, False, "xlabel")
                if icol != 0:
                    set_visible(ax, False, "ylabel")

        return fig

    @add_fig_kwargs
    def combiplot(self, what_list=None, colormap="jet", fontsize=6, **kwargs):
        """
        Plot multiple HIST.nc_ files on a grid. One plot for each ``what`` value.

        Args:
            what_list: List of strings with the quantities to plot. If None, all quantities are plotted.
            colormap: matplotlib color map.
            fontsize: fontsize for legend.

        Returns: |matplotlib-Figure|.
        """
        what_list = (list_strings(what_list) if what_list is not None
            else ["energy", "a", "b", "c", "alpha", "beta", "gamma", "volume", "pressure"])

        num_plots, ncols, nrows = len(what_list), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()
        cmap = plt.get_cmap(colormap)

        for i, (ax, what) in enumerate(zip(ax_list, what_list)):
            # Files are labeled only in the first plot to avoid cluttering the figure.
            label = None
            for ih, hist in enumerate(self.abifiles):
                label = None if i != 0 else hist.relpath
                hist.plot_ax(ax, what, color=cmap(ih / len(self)), label=label, fontsize=fontsize)

            if label is not None:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

            # Show the xlabel only on the lowest visible ax of each column.
            # If num_plots is odd, the ax below the last used one in the other column is hidden.
            if i >= num_plots - ncols:
                ax.set_xlabel("Step")
            else:
                ax.set_xlabel("")

        # Get around a bug in matplotlib.
        if num_plots % ncols != 0: ax_list[-1].axis('off')

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        yield self.gridplot(show=False)
        yield self.combiplot(show=False)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to nbpath. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        args = [(l, f.filepath) for l, f in self.items()]
        nb.cells.extend([
            nbv.new_code_cell("robot = abilab.HistRobot(*%s)\nrobot.trim_paths()\nrobot" % str(args)),
            nbv.new_code_cell("robot.get_dataframe()"),
            # gridplot takes ``what_list``. A ``what`` kwarg would be silently ignored.
            nbv.new_code_cell("for what in robot.what_list: robot.gridplot(what_list=what, tight_layout=True);"),
        ])

        return self._write_nb_nbpath(nb, nbpath)
731
732
class HistReader(ETSF_Reader):
    """
    Reader for the data stored in a HIST.nc_ file.


    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: HistReader
    """

    @lazy_property
    def num_steps(self):
        """Number of iterations present in the HIST.nc_ file (size of the ``time`` dimension)."""
        return self.read_dimvalue("time")

    @lazy_property
    def natom(self):
        """Number of atoms in the unit cell."""
        return self.read_dimvalue("natom")

    def read_all_structures(self):
        """
        Build one |Structure| per iteration step.

        Each structure carries the ``cartesian_forces`` site property in eV/Ang.

        Raises:
            NotImplementedError: if npsp != ntypat (alchemical mixing).
        """
        rprimd_steps = self.read_value("rprimd")
        xred_steps = self.read_value("xred")

        # Alchemical mixing is not supported.
        if self.read_dimvalue("npsp") != self.read_dimvalue("ntypat"):
            raise NotImplementedError("Alchemical mixing is not supported, num_pseudos != ntypat")

        znucl = self.read_value("znucl")
        # typat is stored as a floating-point array, hence the cast.
        typat = self.read_value("typat").astype(int)
        forces_steps = self.read_cart_forces(unit="eV ang^-1")

        def build_structure(istep):
            """Return the structure at step ``istep`` with forces attached."""
            # rprimd is passed as rprim, so acell is set to one (no extra scaling).
            # NOTE: ntypat is not passed explicitly.
            structure = Structure.from_abivars(
                xred=xred_steps[istep],
                rprim=rprimd_steps[istep],
                acell=3 * [1.0],
                znucl=znucl,
                typat=typat,
            )
            structure.add_site_property("cartesian_forces", forces_steps[istep])
            return structure

        return [build_structure(istep) for istep in range(self.num_steps)]

    def read_eterms(self, unit="eV"):
        """|AttrDict| with the decomposition of the total energy in units ``unit``"""
        def energy_in_unit(varname):
            """Read ``varname`` (stored in Ha) and convert it to ``unit``."""
            return units.EnergyArray(self.read_value(varname), "Ha").to(unit)

        return AttrDict(
            etotals=energy_in_unit("etotal"),
            kinetic_terms=energy_in_unit("ekin"),
            entropies=energy_in_unit("entropy"),
        )

    def read_cart_forces(self, unit="eV ang^-1"):
        """
        Read and return a |numpy-array| with the cartesian forces in unit ``unit``.
        Shape (num_steps, natom, 3)
        """
        fcart_ha_bohr = self.read_value("fcart")
        return units.ArrayWithUnit(fcart_ha_bohr, "Ha bohr^-1").to(unit)

    def read_reduced_forces(self):
        """
        Read and return a |numpy-array| with the forces in reduced coordinates
        Shape (num_steps, natom, 3)
        """
        return self.read_value("fred")

    def read_cart_stress_tensors(self):
        """
        Return the stress tensors (nstep x 3 x 3) in cartesian coordinates (GPa)
        and the list of pressures in GPa unit.
        """
        # Abinit stores 6 unique components of this symmetric 3x3 tensor:
        # Given in order (1,1), (2,2), (3,3), (3,2), (3,1), (2,1).
        voigt = self.read_value("strten")
        nstep = self.num_steps
        tensors = np.empty((nstep, 3, 3), dtype=float)

        # Diagonal components: first three entries.
        diag = np.arange(3)
        tensors[:, diag, diag] = voigt[:nstep, :3]

        # Off-diagonal components: mirror each entry to keep the tensor symmetric.
        for offset, (i, j) in enumerate(((2, 1), (2, 0), (1, 0))):
            tensors[:, i, j] = tensors[:, j, i] = voigt[:nstep, 3 + offset]

        tensors *= abu.HaBohr3_GPa
        # Pressure is minus one third of the trace of the stress tensor.
        pressures = -np.trace(tensors, axis1=1, axis2=2) / 3

        return tensors, pressures
827