1# coding: utf-8
2"""
3This module contains objects for postprocessing A2F calculations (phonon lifetimes in metals
4and Eliashberg function).
5
6Warning:
7    Work in progress, DO NOT USE THIS CODE.
8"""
9import numpy as np
10import pymatgen.core.units as units
11import abipy.core.abinit_units as abu
12
13from collections import OrderedDict
14from scipy.integrate import cumtrapz, simps
15from monty.string import marquee, list_strings
16from monty.functools import lazy_property
17from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter
18from abipy.core.kpoints import Kpath
19from abipy.tools.plotting import (add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, set_visible,
20                                  rotate_ticklabels)
21from abipy.tools import duck
22from abipy.electrons.ebands import ElectronDos, RobotWithEbands
23from abipy.dfpt.phonons import PhononBands, PhononDos, RobotWithPhbands
24from abipy.abio.robots import Robot
25from abipy.eph.common import BaseEphReader
26
27
28_LATEX_LABELS = {
29    "lambda_iso": r"$\lambda_{iso}$",
30    "omega_log": r"$\omega_{log}$",
31    "a2f": r"$\alpha^2F(\omega)$",
32    "lambda": r"$\lambda(\omega)$",
33}
34
35
class A2f(object):
    """
    Eliashberg function a2F(w). Energies are in eV.

    Stores the spin-resolved and mode-resolved a2F(w) and provides integrated
    quantities (isotropic lambda, omega_log, moments, McMillan Tc) and plotting methods.
    """
    # Markers used for up/down bands.
    marker_spin = {0: "^", 1: "v"}

    def __init__(self, mesh, values_spin, values_spin_nu, ngqpt, meta):
        """
        Args:
            mesh: Energy mesh in eV (numpy array, its ``shape`` is used by the integration methods).
            values_spin: a2F(w) summed over phonon modes and resolved in spin.
                Array with shape [nsppol, nomega]. A 1D array with shape [nomega]
                is interpreted as nsppol == 1.
            values_spin_nu: a2F(w) decomposed per phonon branch and spin.
                Array with shape [nsppol, natom3, nomega]. A 2D array with shape
                [natom3, nomega] is interpreted as nsppol == 1.
            ngqpt: Q-mesh used to compute A2f.
            meta: Dictionary with metavariables.

        Raises:
            ValueError: If nsppol is not 1 or 2, or if the number of modes is not a multiple of 3.

        TODO:
            1. possibility of computing a2f directly from data on file?
        """
        self.mesh = mesh
        self.ngqpt = ngqpt
        self.meta = meta

        # Spin dependent and total a2F(w)
        values_spin = np.atleast_2d(values_spin)
        values_spin_nu = np.asarray(values_spin_nu)
        if values_spin_nu.ndim == 2:
            # [natom3, nomega] --> [1, natom3, nomega].
            # np.atleast_3d would append the new axis at the end ([natom3, nomega, 1]).
            values_spin_nu = values_spin_nu[np.newaxis]
        values_spin_nu = np.atleast_3d(values_spin_nu)

        self.nsppol = len(values_spin)
        self.nmodes = values_spin_nu.shape[1]
        if self.nmodes % 3 != 0:
            raise ValueError("Number of phonon modes (%d) is not a multiple of 3" % self.nmodes)
        self.natom = self.nmodes // 3

        if self.nsppol == 2:
            self.values = values_spin[0] + values_spin[1]
            self.values_nu = values_spin_nu[0] + values_spin_nu[1]
        elif self.nsppol == 1:
            self.values = values_spin[0]
            self.values_nu = values_spin_nu[0]
        else:
            raise ValueError("Invalid nsppol: %s" % self.nsppol)

        self.values_spin = values_spin
        self.values_spin_nu = values_spin_nu
        #self.lambdaw ?

    @lazy_property
    def iw0(self):
        """
        Index of the first point in the mesh whose value is >= 0.
        Integrals are performed with mesh[iw0 + 1:], i.e. negative frequencies (unstable modes)
        and the point at iw0 are neglected. This also avoids dividing by w at w = 0.

        Raises:
            ValueError: if the mesh has no non-negative point.
        """
        for i, x in enumerate(self.mesh):
            if x >= 0.0: return i
        else:
            raise ValueError("Cannot find zero in energy mesh")

    def __str__(self):
        return self.to_string()

    def to_string(self, title=None, verbose=0):
        """
        String representation with verbosity level ``verbose`` and an optional ``title``.
        """
        lines = []; app = lines.append

        app("Eliashberg Function" if not title else str(title))
        # TODO: Add ElectronDos
        #app("Isotropic lambda: %.3f" % (self.lambda_iso))
        app("Isotropic lambda: %.2f, omega_log: %.3f (eV), %.3f (K)" % (
            self.lambda_iso, self.omega_log, self.omega_log * abu.eV_to_K))
        app("Q-mesh: %s" % str(self.ngqpt))
        app("Mesh from %.4f to %.4f (eV) with %d points" % (
            self.mesh[0], self.mesh[-1], len(self.mesh)))

        if verbose:
            for mustar in (0.1, 0.12, 0.2):
                app("\tFor mustar %s: McMillan Tc: %s [K]" % (mustar, self.get_mcmillan_tc(mustar)))

        if verbose > 1:
            # $\int dw [a2F(w)/w] w^n$
            for n in [0, 4]:
                app("Moment %s: %s" % (n, self.get_moment(n)))

            app("Meta: %s" % str(self.meta))

        return "\n".join(lines)

    @lazy_property
    def lambda_iso(self):
        r"""Isotropic lambda, computed as $\int dw a2F(w)/w$ (moment with n = 0)."""
        return self.get_moment(n=0)

    @lazy_property
    def omega_log(self):
        r"""
        Logarithmic moment of alpha^2F in eV.

        Computed as exp((1/lambda_iso) \int dw a2F(w) ln(w)/w), where
        lambda_iso = \int dw a2F(w)/w. This equals exp((2/\lambda) \int dw a2F(w) ln(w)/w)
        when \lambda is defined as 2 \int dw a2F(w)/w.
        """
        iw = self.iw0 + 1
        wmesh, a2fw = self.mesh[iw:], self.values[iw:]

        fw = a2fw / wmesh * np.log(wmesh)
        integral = simps(fw, x=wmesh)

        return np.exp(1.0 / self.lambda_iso * integral)

    def _integrate_moment(self, a2fw, n, cumulative):
        r"""
        Integrate $a2F(w) w^{n-1}$ with the trapezoidal rule on mesh[iw0 + 1:].

        Args:
            a2fw: a2F(w) values already restricted to mesh[iw0 + 1:].
            n: Order of the moment.
            cumulative: True to return the primitive on the full mesh (zero for points <= iw0),
                False to return only the total integral.
        """
        wmesh = self.mesh[self.iw0 + 1:]

        # Primitive is given on the same mesh as self.
        ff = a2fw * (wmesh ** (n - 1))
        vals = np.zeros(self.mesh.shape)
        vals[self.iw0 + 1:] = cumtrapz(ff, x=wmesh, initial=0.0)

        return vals if cumulative else vals[-1].copy()

    def get_moment(self, n, spin=None, cumulative=False):
        r"""
        Computes the moment of a2F(w) i.e. $\int dw [a2F(w)/w] w^n$
        From Allen PRL 59 1460 (See also Grimvall, Eq 6.72 page 175)

        Args:
            n: Order of the moment.
            spin: Spin index. None to use a2F(w) summed over spins.
            cumulative: True to return the cumulative integral on the full mesh.
        """
        if spin is None:
            a2fw = self.values[self.iw0 + 1:]
        else:
            a2fw = self.values_spin[spin][self.iw0 + 1:]

        return self._integrate_moment(a2fw, n, cumulative)

    def get_moment_nu(self, n, nu, spin=None, cumulative=False):
        r"""
        Computes the moment of the contribution of phonon mode ``nu`` to a2F(w)
        i.e. $\int dw [a2F_\nu(w)/w] w^n$
        From Allen PRL 59 1460 (See also Grimvall, Eq 6.72 page 175)

        Args:
            n: Order of the moment.
            nu: Phonon mode index.
            spin: Spin index. None to use a2F(w) summed over spins.
            cumulative: True to return the cumulative integral on the full mesh.
        """
        if spin is None:
            a2fw = self.values_nu[nu][self.iw0 + 1:]
        else:
            a2fw = self.values_spin_nu[spin][nu][self.iw0 + 1:]

        return self._integrate_moment(a2fw, n, cumulative)

    def get_mcmillan_tc(self, mustar):
        """
        Computes the critical temperature with the McMillan equation and the input mustar.

        Return: Tc in Kelvin.
        """
        tc = (self.omega_log / 1.2) * \
            np.exp(-1.04 * (1.0 + self.lambda_iso) / (self.lambda_iso - mustar * (1.0 + 0.62 * self.lambda_iso)))

        return tc * abu.eV_to_K

    def get_mustar_from_tc(self, tc):
        """
        Return the value of mustar that gives the critical temperature ``tc`` in the McMillan equation
        (inverse of get_mcmillan_tc).

        Args:
            tc: Critical temperature in Kelvin.
        """
        l = self.lambda_iso
        num = l + (1.04 * (1 + l) / np.log(1.2 * abu.kb_eVK * tc / self.omega_log))

        return num / (1 + 0.62 * l)

    @add_fig_kwargs
    def plot(self, what="a2f", units="eV", exchange_xy=False, ax=None,
             xlims=None, ylims=None, label=None, fontsize=12, **kwargs):
        """
        Plot a2F(w) or lambda(w) depending on the value of `what`.

        Args:
            what: "a2f" for a2F(w), "lambda" for lambda(w).
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz"). Case-insensitive.
            exchange_xy: True to exchange x-y axes.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used
            ylims: Limits for y-axis. See xlims for API.
            label: String used as legend label. The legend is shown only if label is given.
            fontsize: Legend and title fontsize
            kwargs: linestyle, color, linewidth passed to ax.plot.

        Returns: |matplotlib-Figure|

        Raises:
            ValueError: if `what` is not "a2f" or "lambda".
        """""
        # Validate `what` before creating the figure. Without this check an unknown value
        # raises KeyError from _LATEX_LABELS instead of the intended ValueError.
        if what not in ("a2f", "lambda"):
            raise ValueError("Invalid value for what: `%s`" % str(what))

        ax, fig, plt = get_ax_fig_plt(ax=ax)
        wfactor = abu.phfactor_ev2units(units)
        ylabel = _LATEX_LABELS[what]

        style = dict(
            linestyle=kwargs.pop("linestyle", "-"),
            color=kwargs.pop("color", "k"),
            linewidth=kwargs.pop("linewidth", 1),
        )

        if what == "a2f":
            # Plot a2f(w)
            xx, yy = self.mesh * wfactor, self.values
            if exchange_xy: xx, yy = yy, xx
            ax.plot(xx, yy, label=label, **style)

            if self.nsppol == 2:
                # Plot spin resolved a2f(w).
                for spin in range(self.nsppol):
                    xx, yy = self.mesh * wfactor, self.values_spin[spin]
                    if exchange_xy: xx, yy = yy, xx
                    ax.plot(xx, yy, marker=self.marker_spin[spin], **style)
        else:
            # Plot lambda(w)
            lambda_w = self.get_moment(n=0, cumulative=True)
            xx, yy = self.mesh * wfactor, lambda_w
            if exchange_xy: xx, yy = yy, xx
            ax.plot(xx, yy, label=label, **style)

        xlabel = abu.wlabel_from_units(units)
        if exchange_xy: xlabel, ylabel = ylabel, xlabel

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True)
        set_axlims(ax, xlims, "x")
        set_axlims(ax, ylims, "y")
        if label: ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig

    @add_fig_kwargs
    def plot_with_lambda(self, units="eV", ax=None, xlims=None, fontsize=12, **kwargs):
        """
        Plot a2F(w) and lambda(w) on the same figure (lambda(w) on a twin y-axis on the right).

        Args:
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz"). Case-insensitive.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used
            fontsize: Legend and title fontsize
            kwargs: Passed to :meth:`plot` (e.g. linestyle, color, linewidth).

        Returns: |matplotlib-Figure|
        """""
        ax, fig, plt = get_ax_fig_plt(ax=ax)
        for i, what in enumerate(["a2f", "lambda"]):
            this_ax = ax if i == 0 else ax.twinx()
            self.plot(what=what, ax=this_ax, units=units, fontsize=fontsize, xlims=xlims, show=False, **kwargs)
            if i > 0:
                this_ax.yaxis.set_label_position("right")
                this_ax.grid(False)

        return fig

    @add_fig_kwargs
    def plot_nuterms(self, units="eV", ax_mat=None, with_lambda=True, fontsize=12,
                     xlims=None, ylims=None, label=None, **kwargs):
        """
        Plot a2F(w), lambda(w) and optionally the individual contributions due to the phonon branches.

        Args:
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            ax_mat: Matrix of axis of shape [natom, 3]. None if a new figure should be created.
            with_lambda: True to plot lambda_nu(w) on a twin y-axis of each panel.
            fontsize: Legend and title fontsize.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used
            ylims: Limits for the a2F(w) y-axis. See xlims for API.
            label: Currently unused.
            kwargs: linestyle, color, linewidth used for the a2F(w) curves.

        Returns: |matplotlib-Figure|
        """""
        # Get ax_mat and fig.
        nrows, ncols = self.natom, 3
        ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=True, squeeze=False)
        ax_mat = np.reshape(ax_mat, (self.natom, 3))

        wfactor = abu.phfactor_ev2units(units)
        wvals = self.mesh * wfactor

        if with_lambda:
            lax_nu = [ax.twinx() for ax in ax_mat.flat]
            # twinx axes share the x-axis with their parent, so only the y-axes of the twins
            # must be linked. Axes.sharey replaces get_shared_y_axes().join, which is
            # not available in recent matplotlib versions.
            for lax in lax_nu[1:]:
                lax.sharey(lax_nu[0])
            # lax_nu follows the row-major order of ax_mat, i.e. index = idir + 3 * iatom.
            # Keep the lambda tick labels only in the right-most column (idir == 2) of every row.
            # tick_params is used instead of set_yticklabels([]) because shared axes
            # share the tick formatter.
            for i, lax in enumerate(lax_nu):
                if i % ncols == ncols - 1: continue
                lax.tick_params(axis="y", labelright=False)

        # TODO Better handling of styles
        a2f_style = dict(
            linestyle=kwargs.pop("linestyle", "-"),
            color=kwargs.pop("color", "k"),
            linewidth=kwargs.pop("linewidth", 1),
        )
        lambda_style = a2f_style.copy()
        lambda_style["color"] = "red"

        import itertools
        for idir, iatom in itertools.product(range(3), range(self.natom)):
            nu = idir + 3 * iatom
            ax = ax_mat[iatom, idir]
            ax.grid(True)
            ax.set_title(r"$\nu = %d$" % nu, fontsize=fontsize)
            if idir == 0:
                ax.set_ylabel(r"$\alpha^2F(\omega)$")

            if iatom == self.natom - 1:
                ax.set_xlabel(abu.wlabel_from_units(units))

            # Plot a2f(w) for this phonon mode.
            ax.plot(wvals, self.values_nu[nu], **a2f_style)
            set_axlims(ax, xlims, "x")
            set_axlims(ax, ylims, "y")

            # Plot lambda(w)
            if with_lambda:
                lambdaw_nu = self.get_moment_nu(n=0, nu=nu, cumulative=True)
                lax = lax_nu[nu]
                lax.plot(wvals, lambdaw_nu, **lambda_style)
                if idir == 2:
                    lax.set_ylabel(r"$\lambda_{\nu}(\omega)$", color=lambda_style["color"])

            #if self.nsppol == 2:
            #   # Plot spin resolved a2f(w)
            #   for spin in range(self.nsppol):
            #       ax.plot(self.mesh, self.values_spin_nu[spin, nu],
            #               marker=self.marker_spin[spin], **a2f_style)

        return fig

    @add_fig_kwargs
    def plot_a2(self, phdos, atol=1e-12, **kwargs):
        """
        Grid with 3 plots showing: a2F(w) / F(w), F(w), a2F(w). Requires phonon DOS.

        Args:
            phdos: |PhononDos|
            atol: F(w) is replaced by atol in a2F(w) / F(w) ratio where :math:`|F(w)|` < atol

        Returns: |matplotlib-Figure|
        """
        phdos = PhononDos.as_phdos(phdos)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=3, ncols=1,
                                                sharex=True, sharey=False, squeeze=True)
        ax_list = ax_list.ravel()

        # Spline phdos onto a2f mesh and compute a2F(w) / F(w)
        phdos_w = phdos.spline(self.mesh)
        ratio = self.values / np.where(np.abs(phdos_w) > atol, phdos_w, atol)
        ax = ax_list[0]
        ax.plot(self.mesh, ratio, color="k", linestyle="-")
        ax.grid(True)
        ax.set_ylabel(r"$\alpha^2(\omega)$ [1/eV]")

        # Plot F(w). TODO: This should not be called plot_dos_idos!
        ax = ax_list[1]
        phdos.plot_dos_idos(ax=ax, what="d", color="k", linestyle="-")
        ax.grid(True)
        ax.set_ylabel(r"$F(\omega)$ [states/eV]")

        # Plot a2f. `linewidth` (not `linewidths`) is the keyword consumed by A2f.plot.
        self.plot(ax=ax_list[2], color="k", linestyle="-", linewidth=2, show=False)

        return fig

    @add_fig_kwargs
    def plot_tc_vs_mustar(self, start=0.1, stop=0.3, num=50, ax=None, **kwargs):
        """
        Plot Tc(mustar) computed with the McMillan equation (log scale on the y-axis).

        Args:
            start: The starting value of the sequence.
            stop: The end value of the sequence
            num (int): optional. Number of samples to generate. Default is 50. Must be non-negative.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            kwargs: Passed to ax.plot.

        Returns: |matplotlib-Figure|
        """
        # TODO start and stop to avoid singularity in Mc Tc
        mustar_values = np.linspace(start, stop, num=num)
        tc_vals = [self.get_mcmillan_tc(mustar) for mustar in mustar_values]

        ax, fig, plt = get_ax_fig_plt(ax=ax)
        ax.plot(mustar_values, tc_vals, **kwargs)
        ax.set_yscale("log")
        ax.grid(True)
        ax.set_xlabel(r"$\mu^*$")
        ax.set_ylabel(r"$T_c$ [K]")

        return fig
435
436
class A2Ftr(object):
    """
    Transport Eliashberg function a2F_tr(w). Energies are in eV.

    Work in progress: the object stores the mesh and the raw in/out arrays;
    no analysis method is implemented yet.
    """
    # Markers used for up/down bands (collinear spin)
    marker_spin = {0: "^", 1: "v"}

    def __init__(self, mesh, vals_in, vals_out):
        """
        Args:
            mesh: Energy mesh in eV
            vals_in: Eliashberg transport function for in-scattering.
                Fortran layout (nomega, 3, 3, 0:natom3, nsppol):
                vals_in(w,3,3,1:natom3,1:nsppol): a2f_tr(w) decomposed per phonon branch and spin
                vals_in(w,3,3,0,1:nsppol): a2f_tr(w) summed over phonons modes, decomposed in spin
            vals_out: Eliashberg transport function for out-scattering, same layout as vals_in.
        """
        self.mesh = mesh
        # Keep the input arrays: previously they were accepted and silently discarded.
        self.vals_in = vals_in
        self.vals_out = vals_out

    @lazy_property
    def iw0(self):
        """
        Index of the first point in the mesh whose value is >= 0
        Integrals are performed with wmesh[iw0 + 1, :] i.e. unstable modes are neglected.

        Raises:
            ValueError: if the mesh has no non-negative point.
        """
        for i, x in enumerate(self.mesh):
            if x >= 0.0: return i
        else:
            raise ValueError("Cannot find zero in energy mesh")
465
466
467class A2fFile(AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter):
468    """
469    This file contains the phonon linewidths, EliashbergFunction, the |PhononBands|,
470    the |ElectronBands| and |ElectronDos| on the k-mesh.
471    Provides methods to analyze and plot results.
472
473    Usage example:
474
475    .. code-block:: python
476
477        with A2fFile("out_A2F.nc") as ncfile:
478            print(ncfile)
479            ncfile.ebands.plot()
480            ncfile.phbands.plot()
481
482    .. rubric:: Inheritance Diagram
483    .. inheritance-diagram:: A2fFile
484    """
485    @classmethod
486    def from_file(cls, filepath):
487        """Initialize the object from a netcdf_ file."""
488        return cls(filepath)
489
490    def __init__(self, filepath):
491        super().__init__(filepath)
492        self.reader = A2fReader(filepath)
493
494    def __str__(self):
495        """String representation."""
496        return self.to_string()
497
498    def to_string(self, verbose=0):
499        """String representation."""
500        lines = []; app = lines.append
501
502        app(marquee("File Info", mark="="))
503        app(self.filestat(as_string=True))
504        app("")
505        app(self.structure.to_string(verbose=verbose, title="Structure"))
506        app("")
507        app(self.ebands.to_string(with_structure=False, verbose=verbose, title="Electronic Bands"))
508        app("")
509        app(self.phbands.to_string(with_structure=False, verbose=verbose, title="Phonon Bands"))
510        app("")
511        # E-PH section
512        app(marquee("E-PH calculation", mark="="))
513        app("K-mesh for electrons:")
514        app(self.ebands.kpoints.ksampling.to_string(verbose=verbose))
515        if verbose:
516            app("Has transport a2Ftr(w): %s" % self.has_a2ftr)
517        app("")
518        a2f = self.a2f_qcoarse
519        app("a2f(w) on the %s q-mesh (ddb_ngqpt|eph_ngqpt)" % str(a2f.ngqpt))
520        app("Isotropic lambda: %.2f, omega_log: %.3f (eV), %.3f (K)" % (
521            a2f.lambda_iso, a2f.omega_log, a2f.omega_log * abu.eV_to_K))
522        #app(self.a2f_qcoarse.to_string(title=title, verbose=verbose))
523        app("")
524        a2f = self.a2f_qintp
525        app("a2f(w) Fourier interpolated on the %s q-mesh (ph_ngqpt)" % str(a2f.ngqpt))
526        app("Isotropic lambda: %.2f, omega_log: %.3f (eV), %.3f (K)" % (
527            a2f.lambda_iso, a2f.omega_log, a2f.omega_log * abu.eV_to_K))
528        #app(self.a2f_qintp.to_string(title=title, verbose=verbose))
529
530        return "\n".join(lines)
531
532    @lazy_property
533    def ebands(self):
534        """|ElectronBands| object."""
535        return self.reader.read_ebands()
536
537    @lazy_property
538    def edos(self):
539        """|ElectronDos| object with e-DOS computed by Abinit."""
540        return self.reader.read_edos()
541
542    @property
543    def structure(self):
544        """|Structure| object."""
545        return self.ebands.structure
546
547    @property
548    def phbands(self):
549        """
550        |PhononBands| object with frequencies along the q-path.
551        Contains (interpolated) linewidths.
552        """
553        return self.reader.read_phbands_qpath()
554
555    @lazy_property
556    def params(self):
557        """:class:`OrderedDict` with parameters that might be subject to convergence studies."""
558        od = self.get_ebands_params()
559        # Add EPH parameters.
560        od.update(self.reader.common_eph_params)
561
562        return od
563
564    @lazy_property
565    def a2f_qcoarse(self):
566        """
567        :class:`A2f` with the Eliashberg function a2F(w) computed on the (coarse) ab-initio q-mesh.
568        """
569        return self.reader.read_a2f(qsamp="qcoarse")
570
571    @lazy_property
572    def a2f_qintp(self):
573        """
574        :class:`A2f` with the Eliashberg function a2F(w) computed on the dense q-mesh by Fourier interpolation.
575        """
576        return self.reader.read_a2f(qsamp="qintp")
577
578    def get_a2f_qsamp(self, qsamp):
579        """Return the :class:`A2f` object associated to q-sampling ``qsamp``."""
580        if qsamp == "qcoarse": return self.a2f_qcoarse
581        if qsamp == "qintp": return self.a2f_qintp
582        raise ValueError("Invalid value for qsamp `%s`" % str(qsamp))
583
584    @lazy_property
585    def has_a2ftr(self):
586        """True if the netcdf file contains transport data."""
587        return "a2ftr_qcoarse" in self.reader.rootgrp.variables
588
589    @lazy_property
590    def a2ftr_qcoarse(self):
591        """
592        :class:`A2ftr` with the Eliashberg transport spectral function a2F_tr(w, x, x')
593        computed on the (coarse) ab-initio q-mesh
594        """
595        if not self.has_a2ftr: return None
596        return self.reader.read_a2ftr(qsamp="qcoarse")
597
598    @lazy_property
599    def a2ftr_qintp(self):
600        """
601        :class:`A2ftr` with the Eliashberg transport spectral function a2F_tr(w, x, x')
602        computed on the dense q-mesh by Fourier interpolation.
603        """
604        if not self.has_a2ftr: return None
605        return self.reader.read_a2ftr(qsamp="qintp")
606
607    def get_a2ftr_qsamp(self, qsamp):
608        """Return the :class:`A2ftr` object associated to q-sampling ``qsamp``."""
609        if qsamp == "qcoarse": return self.a2ftr_qcoarse
610        if qsamp == "qintp": return self.a2ftr_qintp
611        raise ValueError("Invalid value for qsamp `%s`" % str(qsamp))
612
613    def close(self):
614        """Close the file."""
615        self.reader.close()
616
617    #def interpolate(self, ddb, lpratio=5, vertices_names=None, line_density=20, filter_params=None, verbose=0):
618    #    """
619    #    Interpolate the phonon linewidths on a k-path and, optionally, on a k-mesh.
620
621    #    Args:
622    #        lpratio: Ratio between the number of star functions and the number of ab-initio k-points.
623    #            The default should be OK in many systems, larger values may be required for accurate derivatives.
624    #        vertices_names: Used to specify the k-path for the interpolated QP band structure
625    #            when ``ks_ebands_kpath`` is None.
626    #            It's a list of tuple, each tuple is of the form (kfrac_coords, kname) where
627    #            kfrac_coords are the reduced coordinates of the k-point and kname is a string with the name of
628    #            the k-point. Each point represents a vertex of the k-path. ``line_density`` defines
629    #            the density of the sampling. If None, the k-path is automatically generated according
630    #            to the point group of the system.
631    #        line_density: Number of points in the smallest segment of the k-path. Used with ``vertices_names``.
632    #        filter_params: TO BE DESCRIBED
633    #        verbose: Verbosity level
634
635    #    Returns:
636    #    """
637    #    # Get symmetries from abinit spacegroup (read from file).
638    #    abispg = self.structure.abi_spacegroup
639    #    fm_symrel = [s for (s, afm) in zip(abispg.symrel, abispg.symafm) if afm == 1]
640
641    #    phbst_file, phdos_file = ddb.anaget_phbst_and_phdos_files(nqsmall=0, ndivsm=10, asr=2, chneut=1, dipdip=1,
642    #        dos_method="tetra", lo_to_splitting="automatic", ngqpt=None, qptbounds=None, anaddb_kwargs=None, verbose=0,
643    #        mpi_procs=1, workdir=None, manager=None)
644
645    #    phbands = phbst_file.phbands
646    #    phbst_file.close()
647
648    #    # Read qibz and ab-initio linewidths from file.
649    #    qcoords_ibz = self.reader.read_value("qibz")
650    #    data_ibz = self.reader.read_value("phgamma_qibz") * units.Ha_to_eV
651    #    import matplotlib.pyplot as plt
652    #    plt.plot(data_ibz[0])
653    #    plt.show()
654
655    #    # Build interpolator.
656    #    from abipy.core.skw import SkwInterpolator
657    #    cell = (self.structure.lattice.matrix, self.structure.frac_coords, self.structure.atomic_numbers)
658
659    #    has_timrev = True
660    #    fermie, nelect = 0.0, 3 * len(self.structure)
661    #    skw = SkwInterpolator(lpratio, qcoords_ibz, data_ibz, fermie, nelect,
662    #                          cell, fm_symrel, has_timrev,
663    #                          filter_params=filter_params, verbose=verbose)
664
665    #    # Interpolate and set linewidths.
666    #    qfrac_coords = [q.frac_coords for q in phbands.qpoints]
667    #    phbands.linewidths = skw.interp_kpts(qfrac_coords).eigens
668
669    #    return phbands
670
671    @add_fig_kwargs
672    def plot_eph_strength(self, what_list=("phbands", "gamma", "lambda"), ax_list=None,
673                          ylims=None, label=None, fontsize=12, **kwargs):
674        """
675        Plot phonon bands with EPH coupling strength lambda(q, nu) and lambda(q, nu)
676        These values have been Fourier interpolated by Abinit.
677
678        Args:
679            what_list: ``phfreqs`` for phonons, `lambda`` for the eph coupling strength,
680                ``gamma`` for phonon linewidths.
681            ax_list: List of |matplotlib-Axes| (same length as what_list)
682                or None if a new figure should be created.
683            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
684                or scalar e.g. ``left``. If left (right) is None, default values are used
685            label: String used to label the plot in the legend.
686            fontsize: Legend and title fontsize.
687
688        Returns: |matplotlib-Figure|
689        """
690        what_list = list_strings(what_list)
691        nrows, ncols = len(what_list), 1
692        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
693                                                sharex=True, sharey=False, squeeze=False)
694        ax_list = np.array(ax_list).ravel()
695        units = "eV"
696
697        for i, (ax, what) in enumerate(zip(ax_list, what_list)):
698            # Decorate the axis (e.g add ticks and labels).
699            self.phbands.decorate_ax(ax, units="")
700
701            if what == "phbands":
702                # Plot phonon bands
703                self.phbands.plot(ax=ax, units=units, show=False)
704            else:
705                # Add eph coupling.
706                if what == "lambda":
707                    yvals = self.reader.read_phlambda_qpath()
708                    ylabel = r"$\lambda(q,\nu)$"
709                elif what == "gamma":
710                    yvals = self.reader.read_phgamma_qpath()
711                    ylabel = r"$\gamma(q,\nu)$ (eV)"
712                else:
713                    raise ValueError("Invalid value for what: `%s`" % str(what))
714
715                style = dict(
716                    linestyle=kwargs.pop("linestyle", "-"),
717                    color=kwargs.pop("color", "k"),
718                    linewidth=kwargs.pop("linewidth", 1),
719                )
720
721                xvals = np.arange(len(self.phbands.qpoints))
722                for nu in self.phbands.branches:
723                    ax.plot(xvals, yvals[:, nu],
724                            label=label if (nu == 0 and label) else None,
725                            **style)
726
727                ax.set_ylabel(ylabel)
728
729        set_axlims(ax, ylims, "y")
730        if label: ax.legend(loc="best", shadow=True, fontsize=fontsize)
731
732        return fig
733
734    @add_fig_kwargs
735    def plot(self, what="gamma", units="eV", scale=None, alpha=0.6, ylims=None, ax=None, colormap="jet", **kwargs):
736        """
737        Plot phonon bands with gamma(q, nu) or lambda(q, nu) depending on the vaue of `what`.
738
739        Args:
740            what: ``lambda`` for eph coupling strength, ``gamma`` for phonon linewidths.
741            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
742                Case-insensitive.
743            scale: float used to scale the marker size.
744            alpha: The alpha blending value for the markers between 0 (transparent) and 1 (opaque)
745            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
746                   or scalar e.g. ``left``. If left (right) is None, default values are used
747            ax: |matplotlib-Axes| or None if a new figure should be created.
748            colormap: matplotlib color map.
749
750        Returns: |matplotlib-Figure|
751        """
752        ax, fig, plt = get_ax_fig_plt(ax=ax)
753        cmap = plt.get_cmap(colormap)
754
755        # Plot phonon bands.
756        self.phbands.plot(ax=ax, units=units, show=False)
757
758        # Add eph coupling.
759        xvals = np.arange(len(self.phbands.qpoints))
760        wvals = self.phbands.phfreqs * abu.phfactor_ev2units(units)
761
762        # Sum contributions over nsppol (if spin-polarized)
763        # TODO units
764        gammas = self.reader.read_phgamma_qpath()
765        lambdas = self.reader.read_phlambda_qpath()
766        if what == "lambda":
767            scale = 500 if scale is None else float(scale)
768            sqn = scale * np.abs(lambdas)
769            cqn = gammas
770        elif what == "gamma":
771            scale = 10 ** 6 if scale is None else float(scale)
772            sqn = scale * np.abs(gammas)
773            cqn = lambdas
774        else:
775            raise ValueError("Invalid what: `%s`" % str(what))
776
777        vmin, vmax = cqn.min(), cqn.max()
778
779        sc = ax.scatter(np.tile(xvals, len(self.phbands.branches)),
780                        wvals.T, # [q, nu] --> [nu, q]
781                        s=sqn.T,
782                        c=cqn.T,
783                        vmin=vmin, vmax=vmax,
784                        cmap=cmap,
785                        marker="o",
786                        alpha=alpha,
787                        #label=term if ib == 0 else None
788        )
789
790        # Make a color bar
791        #plt.colorbar(sc, ax=ax, orientation="horizontal", pad=0.2)
792        set_axlims(ax, ylims, "y")
793
794        return fig
795
796    @add_fig_kwargs
797    def plot_a2f_interpol(self, units="eV", ylims=None, fontsize=8, **kwargs):
798        """
799        Compare ab-initio a2F(w) with interpolated values.
800
801        Args:
802            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
803                Case-insensitive.
804            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
805                or scalar e.g. ``left``. If left (right) is None, default values are used
806            fontsize: Legend and title fontsize
807
808        Returns: |matplotlib-Figure|
809        """
810        what_list = ["a2f", "lambda"]
811        nrows, ncols = len(what_list), 1
812        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
813                                                sharex=True, sharey=False, squeeze=False)
814        ax_list = np.array(ax_list).ravel()
815
816        styles = dict(
817            qcoarse={"linestyle": "--", "color": "b"},
818            qintp={"linestyle": "-", "color": "r"},
819        )
820
821        for ix, (ax, what) in enumerate(zip(ax_list, what_list)):
822            for qsamp in ["qcoarse", "qintp"]:
823                a2f = self.get_a2f_qsamp(qsamp)
824                a2f.plot(what=what, ax=ax, units=units, ylims=ylims, fontsize=fontsize,
825                         label=qsamp if ix == 0 else None,
826                         show=False, **styles[qsamp])
827
828        return fig
829
830    @add_fig_kwargs
831    def plot_with_a2f(self, what="gamma", units="eV", qsamp="qintp", phdos=None, ylims=None, **kwargs):
832        """
833        Plot phonon bands with lambda(q, nu) + a2F(w) + phonon DOS.
834
835        Args:
836            what: ``lambda`` for eph coupling strength, ``gamma`` for phonon linewidths.
837            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz"). Case-insensitive.
838            qsamp:
839            phdos: |PhononDos| object. Used to plot the PhononDos on the right.
840            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
841                or scalar e.g. ``left``. If left (right) is None, default values are used
842
843        Returns: |matplotlib-Figure|
844        """
845        # Max three additional axes with [a2F, a2F_tr, DOS]
846        ncols = 2
847        width_ratios = [1, 0.2]
848        if self.has_a2ftr:
849            ncols += 1
850            width_ratios.append(0.2)
851
852        if phdos is not None:
853            phdos = PhononDos.as_phdos(phdos)
854            ncols += 1
855            width_ratios.append(0.2)
856
857        # Build grid plot.
858        import matplotlib.pyplot as plt
859        from matplotlib.gridspec import GridSpec
860        fig = plt.figure()
861        gspec = GridSpec(1, ncols, width_ratios=width_ratios, wspace=0.05)
862        ax_phbands = plt.subplot(gspec[0])
863
864        ax_doses = []
865        for i in range(ncols - 1):
866            ax = plt.subplot(gspec[i + 1], sharey=ax_phbands)
867            ax.grid(True)
868            set_axlims(ax, ylims, "y")
869            ax_doses.append(ax)
870
871        # Plot phonon bands with markers.
872        self.plot(what=what, units=units, ylims=ylims, ax=ax_phbands, show=False)
873
874        # Plot a2F(w)
875        a2f = self.get_a2f_qsamp(qsamp)
876        ax = ax_doses[0]
877        a2f.plot(units=units, exchange_xy=True, ylims=ylims, ax=ax, show=False)
878        ax.yaxis.set_ticks_position("right")
879        #ax.yaxis.set_label_position("right")
880        #ax.tick_params(labelbottom='off')
881        ax.set_ylabel("")
882
883        # Plot a2Ftr(w)
884        ix = 1
885        if self.has_a2ftr:
886            ax = ax_doses[ix]
887            a2ftr = self.get_a2ftr_qsamp(qsamp)
888            self.a2ftr.plot(units=units, exchange_xy=True, ylims=ylims, ax=ax, show=False)
889            ax.yaxis.set_ticks_position("right")
890            #ax.yaxis.set_label_position("right")
891            #ax.tick_params(labelbottom='off')
892            ax.set_ylabel("")
893            ix += 1
894
895        # Plot DOS g(w)
896        if phdos is not None:
897            ax = ax_doses[ix]
898            phdos.plot_dos_idos(ax=ax, exchange_xy=True, what="d", color="k", linestyle="-")
899            ax.yaxis.set_ticks_position("right")
900            #ax.yaxis.set_label_position("right")
901            #ax.tick_params(labelbottom='off')
902            ax.set_xlabel(r"$F(\omega)$")
903            #ax.set_ylabel("")
904
905        return fig
906
907    def yield_figs(self, **kwargs):  # pragma: no cover
908        """
909        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
910        Used in abiview.py to get a quick look at the results.
911        """
912        #yield self.plot(show=False)
913        #yield self.plot_eph_strength(show=False)
914        #yield self.plot_with_a2f(show=False)
915
916        #for qsamp in ["qcoarse", "qintp"]:
917        for qsamp in ["qcoarse",]:
918            a2f = self.get_a2f_qsamp(qsamp)
919            yield a2f.plot_with_lambda(title="q-sampling: %s (%s)" % (str(a2f.ngqpt), qsamp), show=False)
920
921        #yield self.plot_nuterms(show=False)
922        #yield self.plot_a2(show=False)
923        #yield self.plot_tc_vs_mustar(show=False)
924
925        #if self.has_a2ftr:
926        #    ncfile.a2ftr.plot();
927
928    def write_notebook(self, nbpath=None):
929        """
930        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporay file in the current
931        working directory is created. Return path to the notebook.
932        """
933        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)
934
935        nb.cells.extend([
936            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
937            nbv.new_code_cell("print(ncfile)"),
938            nbv.new_code_cell("ncfile.ebands.plot();"),
939            nbv.new_code_cell("ncfile.plot();"),
940            #nbv.new_code_cell("ncfile.plot_phlinewidths();"),
941            nbv.new_code_cell("ncfile.plot_with_a2f();"),
942            nbv.new_code_cell("ncfile.a2f.plot();"),
943        ])
944
945        if self.has_a2ftr:
946            nb.cells.extend([
947                nbv.new_code_cell("ncfile.a2ftr.plot();"),
948                #nbv.new_code_cell("ncfile.plot_with_a2ftr();"),
949            ])
950
951        return self._write_nb_nbpath(nb, nbpath)
952
953
class A2fRobot(Robot, RobotWithEbands, RobotWithPhbands):
    """
    This robot analyzes the results contained in multiple A2F.nc files.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: A2fRobot
    """
    #TODO: Method to plot the convergence of DOS(e_F)
    EXT = "A2F"

    # Line style and marker used to distinguish the different q-samplings in the plots.
    linestyle_qsamp = dict(qcoarse="--", qintp="-")
    marker_qsamp = dict(qcoarse="^", qintp="o")

    # q-samplings analyzed when ``qsamps == "all"``. "qintp" is currently disabled.
    #all_qsamps = ["qcoarse", "qintp"]
    all_qsamps = ["qcoarse",]

    def get_dataframe(self, abspath=False, with_geo=False, with_params=True, funcs=None):
        """
        Build and return a |pandas-DataFrame| with the most important results.

        Args:
            abspath: True if paths in index should be absolute. Default: Relative to getcwd().
            with_geo: True if structure info should be added to the dataframe
            funcs: Function or list of functions to execute to add more data to the DataFrame.
                Each function receives a :class:`A2fFile` object and returns a tuple (key, value)
                where key is a string with the name of column and value is the value to be inserted.
            with_params: False to exclude calculation parameters from the dataframe.

        Return: |pandas-DataFrame|
        """
        rows, row_names = [], []
        for label, ncfile in self.items():
            row_names.append(label)
            d = OrderedDict()

            # Isotropic lambda and omega_log for each q-sampling.
            for qsamp in self.all_qsamps:
                a2f = ncfile.get_a2f_qsamp(qsamp)
                d["lambda_" + qsamp] = a2f.lambda_iso
                d["omegalog_" + qsamp] = a2f.omega_log

            # Add transport properties. This loop must not be nested in the a2F loop above,
            # and the value has to be taken from the a2F_tr object.
            if ncfile.has_a2ftr:
                for qsamp in self.all_qsamps:
                    a2ftr = ncfile.get_a2ftr_qsamp(qsamp)
                    d["lambdatr_avg_" + qsamp] = a2ftr.lambda_tr

            # Add info on structure.
            if with_geo:
                d.update(ncfile.structure.get_dict4pandas(with_spglib=True))

            # Add convergence parameters
            if with_params:
                d.update(ncfile.params)

            # Execute functions.
            if funcs is not None: d.update(self._exec_funcs(funcs, ncfile))
            rows.append(d)

        import pandas as pd
        row_names = row_names if not abspath else self._to_relpaths(row_names)
        # An empty robot has no rows[0]: let pandas build an empty frame in that case.
        columns = list(rows[0].keys()) if rows else None
        return pd.DataFrame(rows, index=row_names, columns=columns)

    @add_fig_kwargs
    def plot_lambda_convergence(self, what="lambda", sortby=None, hue=None, ylims=None, fontsize=8,
                                colormap="jet", **kwargs):
        """
        Plot the convergence of the lambda(q, nu) parameters wrt to the ``sortby`` parameter.

        Args:
            what: "lambda" for eph strength, gamma for phonon linewidths.
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function. If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                If callable, the output of hue(abifile) is used.
            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            fontsize: Legend and title fontsize.
            colormap: matplotlib color map.

        Returns: |matplotlib-Figure|
        """
        # Build (1, ngroups) grid plot.
        if hue is None:
            labels_ncfiles_params = self.sortby(sortby, unpack=False)
            nrows, ncols = 1, 1
        else:
            groups = self.group_and_sortby(hue, sortby)
            nrows, ncols = 1, len(groups)

        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=False, squeeze=False)
        cmap = plt.get_cmap(colormap)

        if hue is None:
            # Plot all results on the same figure with different color.
            for i, (label, ncfile, param) in enumerate(labels_ncfiles_params):
                ncfile.plot_eph_strength(what_list=what,
                        ax_list=[ax_mat[0, 0]],
                        ylims=ylims,
                        label=self.sortby_label(sortby, param),
                        color=cmap(i / len(self)), fontsize=fontsize,
                        show=False,
                        )
        else:
            # ngroup figures
            for ig, g in enumerate(groups):
                ax = ax_mat[0, ig]
                label = "%s: %s" % (self._get_label(hue), g.hvalue)
                for ifile, ncfile in enumerate(g.abifiles):
                    ncfile.plot_eph_strength(what_list=what,
                        ax_list=[ax],
                        ylims=ylims,
                        label=label,
                        color=cmap(ifile / len(g)), fontsize=fontsize,
                        show=False,
                        )
                if ig != 0:
                    set_visible(ax, False, "ylabel")

        return fig

    @add_fig_kwargs
    def plot_a2f_convergence(self, sortby=None, hue=None, qsamps="all", xlims=None,
                            fontsize=8, colormap="jet", **kwargs):
        """
        Plot the convergence of the Eliashberg function wrt to the ``sortby`` parameter.

        Args:
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function. If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                If callable, the output of hue(abifile) is used.
            qsamps: String or list of strings with the q-samplings to plot ("qcoarse", "qintp").
                "all" to use all the q-samplings enabled in the robot. One row of panels per q-sampling.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used.
            fontsize: Legend and title fontsize.
            colormap: matplotlib color map.

        Returns: |matplotlib-Figure|
        """
        qsamps = self.all_qsamps if qsamps == "all" else list_strings(qsamps)

        # Build (nqsamps, ngroups) grid plot.
        if hue is None:
            labels_ncfiles_params = self.sortby(sortby, unpack=False)
            nrows, ncols = len(qsamps), 1
        else:
            groups = self.group_and_sortby(hue, sortby)
            nrows, ncols = len(qsamps), len(groups)

        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=True, sharey=False, squeeze=False)
        cmap = plt.get_cmap(colormap)

        for i, qsamp in enumerate(qsamps):
            if hue is None:
                ax = ax_mat[i, 0]
                for j, (label, ncfile, param) in enumerate(labels_ncfiles_params):
                    ncfile.get_a2f_qsamp(qsamp).plot(what="a2f", ax=ax,
                       label=self.sortby_label(sortby, param) + " " + qsamp,
                       color=cmap(j / len(self)), fontsize=fontsize,
                       linestyle=self.linestyle_qsamp[qsamp],
                       show=False,
                    )
                set_axlims(ax, xlims, "x")
            else:
                for ig, g in enumerate(groups):
                    ax = ax_mat[i, ig]
                    label = "%s: %s" % (self._get_label(hue), g.hvalue) + " " + qsamp
                    # Color by position of the file inside the group (as in plot_lambda_convergence).
                    for ifile, ncfile in enumerate(g.abifiles):
                        ncfile.get_a2f_qsamp(qsamp).plot(what="a2f", ax=ax,
                            label=label,
                            color=cmap(ifile / len(g)), fontsize=fontsize,
                            linestyle=self.linestyle_qsamp[qsamp],
                            show=False,
                        )
                    set_axlims(ax, xlims, "x")
                    if ig != 0:
                        set_visible(ax, False, "ylabel")

            # Hide the x-label of every panel that is not in the last row.
            if i != len(qsamps) - 1:
                for ax in ax_mat[i]:
                    set_visible(ax, False, "xlabel")

        return fig

    @add_fig_kwargs
    def plot_a2fdata_convergence(self, sortby=None, hue=None, qsamps="all", what_list=("lambda_iso", "omega_log"),
                                 fontsize=8, **kwargs):
        """
        Plot the convergence of the isotropic lambda and omega_log wrt the ``sortby`` parameter.

        Args:
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function. If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                If callable, the output of hue(abifile) is used.
            qsamps: String or list of strings with the q-samplings to plot ("qcoarse", "qintp").
                "all" to use all the q-samplings enabled in the robot.
            what_list: String or list of strings with the names of the A2f attributes to plot,
                one panel per entry (e.g. "lambda_iso", "omega_log").
            fontsize: Legend and title fontsize.

        Returns: |matplotlib-Figure|
        """
        what_list = list_strings(what_list)

        # Build grid with (n, 1) plots.
        nrows, ncols = len(what_list), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = np.array(ax_list).ravel()

        if hue is None:
            labels, ncfiles, params = self.sortby(sortby, unpack=True)
        else:
            groups = self.group_and_sortby(hue, sortby)

        qsamps = self.all_qsamps if qsamps == "all" else list_strings(qsamps)

        for ix, (ax, what) in enumerate(zip(ax_list, what_list)):
            if hue is None:
                # String parameters cannot be used as abscissas: use their index and set tick labels.
                params_are_string = duck.is_string(params[0])
                xvals = params if not params_are_string else range(len(params))
                l = None
                for iq, qsamp in enumerate(qsamps):
                    a2f_list = [ncfile.get_a2f_qsamp(qsamp) for ncfile in ncfiles]
                    yvals = [getattr(a2f, what) for a2f in a2f_list]
                    # All the q-samplings share the color of the first line.
                    l = ax.plot(xvals, yvals,
                                marker=self.marker_qsamp[qsamp],
                                linestyle=self.linestyle_qsamp[qsamp],
                                color=None if iq == 0 else l[0].get_color(),
                                )
                    if params_are_string:
                        ax.set_xticks(xvals)
                        ax.set_xticklabels(params, fontsize=fontsize)
            else:
                for g in groups:
                    for iq, qsamp in enumerate(qsamps):
                        a2f_list = [ncfile.get_a2f_qsamp(qsamp) for ncfile in g.abifiles]
                        yvals = [getattr(a2f, what) for a2f in a2f_list]
                        label = "%s: %s" % (self._get_label(hue), g.hvalue) if iq == 0 else None
                        l = ax.plot(g.xvalues, yvals, label=label,
                                    marker=self.marker_qsamp[qsamp],
                                    linestyle=self.linestyle_qsamp[qsamp],
                                    color=None if iq == 0 else l[0].get_color(),
                                    )

            ax.grid(True)
            ax.set_ylabel(_LATEX_LABELS[what])
            if ix == len(what_list) - 1:
                ax.set_xlabel("%s" % self._get_label(sortby))
                if sortby is None: rotate_ticklabels(ax, 15)
            if hue is not None:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    @add_fig_kwargs
    def gridplot_a2f(self, xlims=None, fontsize=8, sharex=True, sharey=True, **kwargs):
        """
        Plot grid with a2F(w) and lambda(w) for all files treated by the robot.

        Args:
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used
            sharex, sharey: True to share x- and y-axis.
            fontsize: Legend and title fontsize
        """
        return self._gridplot_a2f_what("a2f", xlims=xlims, fontsize=fontsize, sharex=sharex, sharey=sharey, **kwargs)

    #@add_fig_kwargs
    #def gridplot_a2ftr(self, xlims=None, fontsize=8, sharex=True, sharey=True, **kwargs):
    #    return self._gridplot_a2f_what("a2ftr", xlims=xlims, fontsize=fontsize, sharex=sharex, sharey=sharey, **kwargs)

    def _gridplot_a2f_what(self, what, qsamps="all", xlims=None, fontsize=8, sharex=True, sharey=True, **kwargs):
        """
        Internal method to plot a2F ("a2f") or a2F_tr ("a2ftr") in a grid with one panel per file.

        Raises:
            ValueError: if ``what`` is not "a2f" or "a2ftr".
        """
        nrows, ncols, nplots = 1, 1, len(self)
        if nplots > 1:
            ncols = 2
            nrows = nplots // ncols + nplots % ncols

        # Build grid plot
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=sharex, sharey=sharey, squeeze=False)
        ax_list = ax_list.ravel()
        # don't show the last ax if nplots is odd.
        if nplots % ncols != 0: ax_list[-1].axis("off")

        qsamps = self.all_qsamps if qsamps == "all" else list_strings(qsamps)
        for qsamp in qsamps:
            # Select the spectral function requested by the caller (one object per file).
            if what == "a2f":
                a2f_list = [ncfile.get_a2f_qsamp(qsamp) for ncfile in self.abifiles]
            elif what == "a2ftr":
                a2f_list = [ncfile.get_a2ftr_qsamp(qsamp) for ncfile in self.abifiles]
            else:
                raise ValueError("Invalid value for what: `%s`" % what)

            for i, (a2f, ax, title) in enumerate(zip(a2f_list, ax_list, self.keys())):
                irow, icol = divmod(i, ncols)
                # FIXME: Twinx is problematic
                a2f.plot_with_lambda(ax=ax, show=False,
                                     linestyle=self.linestyle_qsamp[qsamp],
                                     )

                set_axlims(ax, xlims, "x")
                ax.set_title(title, fontsize=fontsize)
                if (irow, icol) != (0, 0):
                    set_visible(ax, False, "ylabel")
                if irow != nrows - 1:
                    set_visible(ax, False, "xlabel")

        return fig

    #@add_fig_kwargs
    #def plot_a2ftr_convergence(self, sortby=None, qsamps="all", ax=None, xlims=None,
    #                           fontsize=8, colormap="jet", **kwargs):
    #    qsamps = self.all_qsamps if qsamps == "all" else list_strings(qsamps)
    #    ax, fig, plt = get_ax_fig_plt(ax=ax)
    #    cmap = plt.get_cmap(colormap)
    #    for i, (label, ncfile, param) in enumerate(self.sortby(sortby)):
    #        for qsamp in qsamps:
    #            ncfile.get_a2ftr_qsamp(qsamp).plot(
    #                    ax=ax=ax,
    #                    label=self.sortby_label(sortby, param),
    #                    color=cmap(i / len(self)),
    #                    show=False,
    #                    )
    #    set_axlims(ax, xlims, "x")
    #    return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        #yield self.plot_lambda_convergence(show=False)
        yield self.plot_a2f_convergence(show=False)
        yield self.plot_a2fdata_convergence(show=False)
        yield self.gridplot_a2f(show=False)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        args = [(l, f.filepath) for l, f in self.items()]
        nb.cells.extend([
            #nbv.new_markdown_cell("# This is a markdown cell"),
            nbv.new_code_cell("robot = abilab.A2fRobot(*%s)\nrobot.trim_paths()\nrobot" % str(args)),
            nbv.new_code_cell("data = robot.get_dataframe()\ndata"),
            nbv.new_code_cell("robot.plot_lambda_convergence();"),
            nbv.new_code_cell("robot.plot_a2f_convergence();"),
        ])

        # plot_a2ftr_convergence is commented out above: emit the cell only if the method
        # actually exists, otherwise the generated notebook would raise AttributeError.
        if hasattr(self, "plot_a2ftr_convergence") and all(ncf.has_a2ftr for ncf in self.abifiles):
            nb.cells.extend([
                nbv.new_code_cell("robot.plot_a2ftr_convergence();"),
            ])

        # Mixins.
        nb.cells.extend(self.get_baserobot_code_cells())
        nb.cells.extend(self.get_ebands_code_cells())
        nb.cells.extend(self.get_phbands_code_cells())

        return self._write_nb_nbpath(nb, nbpath)
1339
1340
1341class A2fReader(BaseEphReader):
1342    """
1343    Reads data from the EPH.nc file and constructs objects.
1344
1345    .. rubric:: Inheritance Diagram
1346    .. inheritance-diagram:: A2fReader
1347    """
1348    def read_edos(self):
1349        """
1350        Read the |ElectronDos| used to compute EPH quantities.
1351        """
1352        mesh = self.read_value("edos_mesh") * units.Ha_to_eV
1353        # [nsppol+1, nw] arrays with TOT_DOS, Spin_up, Spin_down in a.u.
1354        var = self.read_variable("edos_dos")
1355        if var.shape[0] == 3:
1356            # Spin polarized. Extract up-down components.
1357            spin_dos = var[1:, :] / units.Ha_to_eV
1358        else:
1359            # Spin unpolarized. Extract Tot DOS
1360            spin_dos = var[0, :] / units.Ha_to_eV
1361
1362        #spin_idos = self.read_variable("edos_idos")[1:, :] / units.Ha_to_eV
1363        nelect = self.read_value("number_of_electrons")
1364        fermie = self.read_value("fermi_energy") * units.Ha_to_eV
1365
1366        return ElectronDos(mesh, spin_dos, nelect, fermie=fermie)
1367
1368    def read_phbands_qpath(self):
1369        """
1370        Read and return a |PhononBands| object with frequencies computed along the q-path.
1371        """
1372        structure = self.read_structure()
1373
1374        # Build the list of q-points
1375        qpoints = Kpath(structure.reciprocal_lattice,
1376                        frac_coords=self.read_value("qpath"),
1377                        weights=None, names=None, ksampling=None)
1378
1379        #nctkarr_t('phfreq_qpath', "dp", "natom3, nqpath"),&
1380        phfreqs = self.read_value("phfreq_qpath") * units.Ha_to_eV
1381        phdispl_cart = self.read_value("phdispl_cart_qpath", cmode="c") * units.bohr_to_ang
1382
1383        linewidths = self.read_phgamma_qpath()
1384        if self.read_nsppol() == 2:
1385            # We have spin-resolved linewidths, sum over spins here.
1386            linewidths = linewidths.sum(axis=0)
1387
1388        amu_list = self.read_value("atomic_mass_units", default=None)
1389        if amu_list is not None:
1390            atom_species = self.read_value("atomic_numbers")
1391            amu = {at: a for at, a in zip(atom_species, amu_list)}
1392        else:
1393            raise ValueError("atomic_mass_units is not present!")
1394            amu = None
1395
1396        return PhononBands(structure=structure,
1397                           qpoints=qpoints,
1398                           phfreqs=phfreqs,
1399                           phdispl_cart=phdispl_cart,
1400                           non_anal_ph=None,
1401                           amu=amu,
1402                           linewidths=linewidths,
1403                           )
1404
1405    def read_phlambda_qpath(self, sum_spin=True):
1406        """
1407        Reads the EPH coupling strength *interpolated* along the q-path.
1408
1409        Return:
1410            |numpy-array| with shape [nqpath, natom3] if not sum_spin else [nsppol, nqpath, natom3]
1411        """
1412        vals = self.read_value("phlambda_qpath")
1413        return vals if not sum_spin else vals.sum(axis=0)
1414
1415    def read_phgamma_qpath(self, sum_spin=True):
1416        """
1417        Reads the phonon linewidths (eV) *interpolated* along the q-path.
1418
1419        Return:
1420            |numpy-array| with shape [nqpath, natom3] if not sum_spin else [nsppol, nqpath, natom3]
1421        """
1422        vals = self.read_value("phgamma_qpath") * units.Ha_to_eV
1423        return vals if not sum_spin else vals.sum(axis=0)
1424
1425    def read_a2f(self, qsamp):
1426        """
1427        Read and return the Eliashberg function :class:`A2F`.
1428        """
1429        assert qsamp in ("qcoarse", "qintp")
1430        mesh = self.read_value("a2f_mesh_" + qsamp) * units.Ha_to_eV
1431        # C shape [nsppol, natom + 1, nomega]
1432        data = self.read_value("a2f_values_" + qsamp) # * 0.25
1433        values_spin = data[:, 0, :].copy()
1434        values_spin_nu = data[:, 1:, :].copy()
1435
1436        # Extract q-mesh and meta variables.
1437        ngqpt = self.ngqpt if qsamp == "qcoarse" else self.ph_ngqpt
1438        meta = {k: self.common_eph_params[k] for k in
1439                ["eph_intmeth", "eph_fsewin", "eph_fsmear", "eph_extrael", "eph_fermie"]}
1440
1441        return A2f(mesh, values_spin, values_spin_nu, ngqpt, meta)
1442
1443    #def read_a2ftr(self, qsamp):
1444    #    """Read and return the Eliashberg transport spectral function a2F_tr(w, x, x')."""
1445    #    assert qsamp in ("qcoarse", "qintp")
1446    #    mesh = self.read_value("a2ftr_mesh_" + qsamp) * units.Ha_to_eV
1447    #    # Transpose tensor components F --> C
1448    #    vals_in = self.read_value("a2ftr_in_" + qsamp)
1449    #    vals_out = self.read_value("a2ftr_out_" + qsamp)
1450    #    return A2ftr(mesh=mesh, vals_in, vals_out)
1451
1452    #def read_phgamma_ibz_data(self):
1453    #     ! linewidths in IBZ
1454    #     nctkarr_t('qibz', "dp", "number_of_reduced_dimensions, nqibz"), &
1455    #     nctkarr_t('wtq', "dp", "nqibz"), &
1456    #     nctkarr_t('phfreq_qibz', "dp", "natom3, nqibz"), &
1457    #     nctkarr_t('phdispl_cart_qibz', "dp", "two, natom3, natom3, nqibz"), &
1458    #     nctkarr_t('phgamma_qibz', "dp", "natom3, nqibz, number_of_spins"), &
1459    #     nctkarr_t('phlambda_qibz', "dp", "natom3, nqibz, number_of_spins") &
1460