1# coding: utf-8
2"""
Objects to analyze the RTA.nc file produced by Abinit (transport results in the relaxation-time approximation).
4"""
5import numpy as np
6import abipy.core.abinit_units as abu
7
8from monty.functools import lazy_property
9#from monty.termcolor import cprint
10from monty.string import marquee, list_strings
11from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter
12from abipy.electrons.ebands import ElectronsReader, RobotWithEbands
13from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt
14from abipy.abio.robots import Robot
15
16
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "RtaFile",
    "RtaRobot",
]
21
22
def eh2s(eh):
    """Map the carrier index to a one-letter label: 0 (electrons) -> "n", 1 (holes) -> "p"."""
    carrier_labels = {0: "n", 1: "p"}
    return carrier_labels[eh]
25
26
def irta2s(irta):
    """Return the name of the RTA approximation for index irta: 0 -> "SERTA", 1 -> "MRTA"."""
    rta_names = {0: "SERTA", 1: "MRTA"}
    name = rta_names[irta]
    return name
30
31
def style_for_irta(irta, with_marker=False):
    """
    Return a dict with the matplotlib line options used to plot SERTA (irta=0) or MRTA (irta=1) results.

    SERTA curves are dotted, MRTA curves are solid. If with_marker is True, a marker
    ("^" for SERTA, "v" for MRTA) is added as well.

    Raises:
        ValueError: if irta is neither 0 nor 1.
    """
    if irta not in (0, 1):
        raise ValueError("Invalid value for irta: %s" % irta)

    is_serta = irta == 0
    opts = dict(linewidth=1.0, linestyle="dotted" if is_serta else "solid")
    if with_marker:
        opts["marker"] = "^" if is_serta else "v"

    return opts
46
47
def transptens2latex(what, component):
    """
    Return the matplotlib (mathtext) label for a Cartesian component of a transport tensor.

    Args:
        what: One of "sigma", "seebeck", "kappa", "pi", "zte".
        component: Cartesian component as a string e.g. "xx", "xy".

    Raises:
        KeyError: if what is not a known tensor name.
    """
    templates = {
        "sigma": r"$\sigma_{%s}$",
        "seebeck": r"$S_{%s}$",
        "kappa": r"$\kappa^{\mathcal{e}}_{%s}$",
        "pi": r"$\Pi_{%s}$",
        "zte": r"$\text{ZT}^e_{%s}$",
    }
    return templates[what] % component
56
57
def edos_infos(edos_intmeth, edos_broad):
    """
    Return a string describing the method used for the energy integrals of the electron DOS.

    Args:
        edos_intmeth: Integration method: 1 for Gaussian smearing, 2 for the linear tetrahedron
            method, -2 for the linear tetrahedron method with Blochl's corrections.
        edos_broad: Gaussian broadening in Ha (converted to meV with abu.Ha_to_meV).
            Used only when edos_intmeth == 1.

    Raises:
        KeyError: if edos_intmeth is not one of the values listed above.
    """
    s = {1: "Gaussian smearing Method",
         2: "Linear Tetrahedron Method",
        -2: "Linear Tetrahedron Method with Blochl's corrections",
    }[edos_intmeth]

    if edos_intmeth == 1:
        # `%` binds tighter than `*`: both format arguments must be passed as a tuple.
        s = "%s with broadening: %.1f (meV)" % (s, edos_broad * abu.Ha_to_meV)

    return s
66
67
def irta2latextau(irta, with_dollars=False):
    r"""
    Return the LaTeX symbol of the relaxation time for the given RTA type,
    e.g. ``\tau^{\mathbf{SERTA}}`` for irta == 0.

    Args:
        irta: RTA index (0 for SERTA, 1 for MRTA).
        with_dollars: If True, enclose the string in ``$...$`` so that it can be
            used directly as a matplotlib label.
    """
    # Balanced braces: one for the superscript group, one for \mathbf.
    s = r"\tau^{\mathbf{%s}}" % irta2s(irta)
    if with_dollars:
        s = "$%s$" % s
    return s
72
73
def x2_grid(what_list):
    """
    Compute the layout of a grid of subplots with (at most) two columns.

    A single entry (or none) gives a (1, 1) grid. Otherwise two columns are used,
    with as many rows as needed to host all the entries.

    Args:
        what_list: String or list of strings, normalized with monty's list_strings.
            Each entry corresponds to one subplot.

    Return: (num_plots, ncols, nrows, what_list) with what_list normalized to a list.
    """
    what_list = list_strings(what_list)
    num_plots = len(what_list)
    if num_plots <= 1:
        return num_plots, 1, 1, what_list

    ncols = 2
    full_rows, leftover = divmod(num_plots, ncols)
    return num_plots, ncols, full_rows + leftover, what_list
87
88
class RtaFile(AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter):
    """
    Interface to the RTA.nc file with the transport results computed by Abinit
    with the relaxation time approximation (SERTA/MRTA) and, optionally, with the IBTE.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: RtaFile
    """

    @classmethod
    def from_file(cls, filepath):
        """Initialize the object from a netcdf file."""
        return cls(filepath)

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = RtaReader(filepath)

        # Number of RTA types stored in the file (used as the irta range everywhere).
        self.nrta = self.reader.read_dimvalue("nrta")

        #self.fermi = self.ebands.fermie * abu.eV_Ha
        #self.transport_ngkpt = self.reader.read_value("transport_ngkpt")
        #self.transport_extrael = self.reader.read_value("transport_extrael")
        #self.transport_fermie = self.reader.read_value("transport_fermie")
        self.sigma_erange = self.reader.read_value("sigma_erange")
        #self.ebands.kpoints.ksampling.mpdivs

        # Get position of CBM and VBM for each spin in eV
        # nctkarr_t('vb_max', "dp", "nsppol")
        self.vb_max_spin = self.reader.read_value("vb_max") * abu.Ha_to_eV
        self.cb_min_spin = self.reader.read_value("cb_min") * abu.Ha_to_eV

        # Get metadata for k-integration (coming from edos%ncwrite)
        self.edos_intmeth = int(self.reader.read_value("edos_intmeth"))
        self.edos_broad = self.reader.read_value("edos_broad")

        # Store also the e-mesh in eV as it's often needed in the plotting routines.
        # Several quantities are defined on this mesh.
        self.edos_mesh_eV = self.reader.read_value("edos_mesh") * abu.Ha_to_eV

    @property
    def ntemp(self):
        """Number of temperatures."""
        return len(self.tmesh)

    @property
    def tmesh(self):
        """Mesh with Temperatures in Kelvin."""
        return self.reader.tmesh

    @lazy_property
    def assume_gap(self):
        """True if we are dealing with a semiconductor. More precisely if all(sigma_erange) > 0."""
        # Read the stored value: calling bool() on the netCDF variable object itself
        # does not test the value written in the file.
        return bool(self.reader.read_value("assume_gap"))

    @lazy_property
    def has_ibte(self):
        """True if file contains IBTE results."""
        return "ibte_sigma" in self.reader.rootgrp.variables

    @lazy_property
    def ebands(self):
        """|ElectronBands| object."""
        return self.reader.read_ebands()

    @property
    def structure(self):
        """|Structure| object."""
        return self.ebands.structure

    @lazy_property
    def params(self):
        """:class:`OrderedDict` with parameters that might be subject to convergence studies."""
        od = self.get_ebands_params()
        return od

    def __str__(self):
        """String representation."""
        return self.to_string()

    def to_string(self, verbose=0):
        """
        String representation with verbosity level ``verbose``.

        Includes file info, structure, KS bands, the DOS integration method and,
        for each Cartesian component and RTA type, the electron/hole mobility
        at the chemical potential for all temperatures.
        """
        lines = []; app = lines.append

        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.structure.to_string(verbose=verbose, title="Structure"))
        app("")
        app(self.ebands.to_string(with_structure=False, verbose=verbose, title="KS Electron Bands"))
        app("")

        # Transport section.
        app(marquee("Transport calculation", mark="="))
        app("")
        app(edos_infos(self.edos_intmeth, self.edos_broad))
        app("mesh step for energy integrals: %.1f (meV) " % ((self.edos_mesh_eV[1] - self.edos_mesh_eV[0]) * 1000))
        app("")

        components = ("xx", "yy", "zz") if verbose == 0 else ("xx", "yy", "zz", "xy", "xz", "yx")
        for component in components:
            for irta in range(self.nrta):
                app("Mobility (%s Cartesian components), RTA type: %s" % (component, irta2s(irta)))
                app("Temperature [K]     Electrons (cm^2/Vs)     Holes (cm^2/Vs)")
                for itemp in range(self.ntemp):
                    temp = self.tmesh[itemp]
                    mobility_mu_e = self.get_mobility_mu(eh=0, itemp=itemp, component=component, irta=irta)
                    mobility_mu_h = self.get_mobility_mu(eh=1, itemp=itemp, component=component, irta=irta)
                    app("%14.1lf %18.6lf %18.6lf" % (temp, mobility_mu_e, mobility_mu_h))
                app("")

        return "\n".join(lines)

    def get_mobility_mu(self, eh, itemp, component='xx', ef=None, irta=0, spin=0):
        """
        Get the mobility at the chemical potential Ef by linear interpolation on the energy mesh.

        Args:
            eh: 0 for electrons, 1 for holes.
            itemp: Index of the temperature.
            component: Cartesian component to plot: "xx", "yy" "xy" ...
            ef: Chemical potential at which the mobility is interpolated. It must be given in the
                same units as the raw ``edos_mesh`` returned by ``RtaReader.read_mobility`` (Ha).
                The default None uses the chemical potential at temperature itemp as computed by Abinit.
            irta: RTA type index (0: SERTA, 1: MRTA).
            spin: Spin index.
        """
        if ef is None: ef = self.reader.read_value('transport_mu_e')[itemp]
        emesh, mobility = self.reader.read_mobility(eh, itemp, component, spin, irta=irta)

        from scipy import interpolate
        f = interpolate.interp1d(emesh, mobility)
        return f(ef)

    def _add_vline_at_bandedge(self, ax, spin, cbm_or_vbm, **kwargs):
        """
        Add vertical lines at the CBM (red) and/or VBM (blue) of the given spin.

        Args:
            ax: matplotlib Axes.
            spin: Spin index.
            cbm_or_vbm: "cbm", "vbm" or "both".
            kwargs: Options passed to ax.axvline (override the defaults).
        """
        my_kwargs = dict(ymin=0, ymax=1, linewidth=1, linestyle="--")
        my_kwargs.update(kwargs)

        if cbm_or_vbm in ("cbm", "both"):
            x = self.cb_min_spin[spin]
            ax.axvline(x=x, color="red", **my_kwargs)

        if cbm_or_vbm in ("vbm", "both"):
            x = self.vb_max_spin[spin]
            ax.axvline(x=x, color="blue", **my_kwargs)

    @add_fig_kwargs
    def plot_edos(self, ax=None, fontsize=8, **kwargs):
        """
        Plot electron DOS

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize (int): fontsize for titles and legend

        Return: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        # Total DOS, spin up and spin down components in nsppol_plus1.
        # nctkarr_t("edos_dos", "dp", "edos_nw, nsppol_plus1")
        dos = self.reader.read_value("edos_dos") / abu.Ha_to_eV

        # Plot total DOS.
        ax.plot(self.edos_mesh_eV, dos[0], label="Total DOS", color="black", linewidth=1.0)

        if self.nsppol == 2:
            # Spin-down DOS is plotted with negative sign.
            ax.plot(self.edos_mesh_eV, + dos[1], color="red", linewidth=1, label="up")
            ax.plot(self.edos_mesh_eV, - dos[2], color="blue", linewidth=1, label="down")

        for spin in range(self.nsppol):
            self._add_vline_at_bandedge(ax, spin, "both")

        ax.grid(True)
        ax.set_xlabel('Energy (eV)')
        ax.set_ylabel('States/eV p.u.c')
        ax.legend(loc="best", shadow=True, fontsize=fontsize)

        if "title" not in kwargs:
            title = r"$\frac{1}{N_k} \sum_{nk} \delta(\epsilon - \epsilon_{nk})$"
            fig.suptitle(title, fontsize=fontsize)

        return fig

    @add_fig_kwargs
    def plot_tau_isoe(self, ax_list=None, colormap="jet", fontsize=8, **kwargs):
        r"""
        Plot tau(e). Energy-dependent relaxation time defined by:

            $\tau(\epsilon) = \frac{1}{N_k} \sum_{nk} \tau_{nk}\,\delta(\epsilon - \epsilon_{nk})$

        One subplot per RTA type (SERTA, MRTA). Spin-down values are plotted with negative sign.

        Args:
            ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
            colormap: matplotlib colormap used to color the different temperatures.
            fontsize (int): fontsize for titles and legend

        Return: |matplotlib-Figure|
        """
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=self.nrta, ncols=1,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()
        cmap = plt.get_cmap(colormap)

        # nctkarr_t('tau_dos', "dp", "edos_nw, ntemp, nsppol, nrta")
        tau_dos = self.reader.read_value("tau_dos")

        for irta, ax in enumerate(ax_list):
            for spin in range(self.nsppol):
                spin_sign = +1 if spin == 0 else -1
                for it, temp in enumerate(self.tmesh):
                    # Convert to femtoseconds
                    ys = spin_sign * tau_dos[irta, spin, it] * abu.Time_Sec * 1e+15
                    ax.plot(self.edos_mesh_eV, ys, c=cmap(it / self.ntemp),
                            label="T = %dK" % temp if spin == 0 else None)

                # Mark the band edges for every spin.
                self._add_vline_at_bandedge(ax, spin, "both")

            ax.grid(True)
            ax.legend(loc="best", shadow=True, fontsize=fontsize)
            if irta == (len(ax_list) - 1):
                ax.set_xlabel('Energy (eV)')
                ax.set_ylabel(r"$\tau(\epsilon)$ (fs)")

            ax.text(0.1, 0.9, irta2s(irta), fontsize=fontsize,
                horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
                bbox=dict(alpha=0.5))

        if "title" not in kwargs:
            title = r"$\tau(\epsilon) = \frac{1}{N_k} \sum_{nk} \tau_{nk}\,\delta(\epsilon - \epsilon_{nk})$"
            fig.suptitle(title, fontsize=fontsize)

        return fig

    @add_fig_kwargs
    def plot_vvtau_dos(self, component="xx", spin=0, ax=None, colormap="jet", fontsize=8, **kwargs):
        r"""
        Plot (v_i * v_j * tau) DOS.

            $\frac{1}{N_k} \sum_{nk} v_i v_j \tau_{nk} \delta(\epsilon - \epsilon_{nk})$

        Args:
            component: Cartesian component to plot: "xx", "yy" "xy" ...
            spin: Spin index.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            colormap: matplotlib colormap.
            fontsize (int): fontsize for titles and legend

        Return: |matplotlib-Figure|
        """
        i, j = abu.s2itup(component)

        ax, fig, plt = get_ax_fig_plt(ax=ax)
        cmap = plt.get_cmap(colormap)

        # nctkarr_t('vvtau_dos', "dp", "edos_nw, three, three, ntemp, nsppol, nrta")
        # The variable does not depend on irta: read it once outside the loop.
        var = self.reader.read_variable("vvtau_dos")

        for irta in range(self.nrta):
            for itemp, temp in enumerate(self.tmesh):
                vvtau_dos = var[irta, spin, itemp, j, i, :] / (2 * abu.Ha_s)
                label = "T = %dK" % temp
                if itemp == 0: label = "%s (%s)" % (label, irta2s(irta))
                if irta == 0 and itemp > 0: label = None
                ax.plot(self.edos_mesh_eV, vvtau_dos, c=cmap(itemp / self.ntemp), label=label, **style_for_irta(irta))

        self._add_vline_at_bandedge(ax, spin, "both")

        ax.grid(True)
        ax.set_xlabel('Energy (eV)')
        ax.set_ylabel(r'$v_{%s} v_{%s} \tau$ DOS' % (component[0], component[1]))
        ax.set_yscale('log')
        ax.legend(loc="best", shadow=True, fontsize=fontsize)

        if "title" not in kwargs:
            vvt = r'v_{%s} v_{%s} \tau' % (component[0], component[1])
            title = r"$\frac{1}{N_k} \sum_{nk} %s\,\delta(\epsilon - \epsilon_{nk})$" % vvt
            fig.suptitle(title, fontsize=fontsize)

        return fig

    @add_fig_kwargs
    def plot_mobility(self, eh=0, irta=0, component='xx', spin=0, ax=None,
                      colormap='jet', fontsize=8, yscale="log", **kwargs):
        """
        Read the mobility from the netcdf file and plot it as a function of the Fermi level.

        Args:
            eh: 0 for electrons, 1 for holes.
            irta: Not used: the curves of all the RTA types stored in the file are plotted.
            component: Component to plot: "xx", "yy" "xy" ...
            spin: Spin index.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            colormap: matplotlib colormap.
            fontsize (int): fontsize for titles and legend
            yscale: Scale of the y-axis passed to ax.set_yscale (e.g. "log", "linear").

        Return: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)
        cmap = plt.get_cmap(colormap)

        # nctkarr_t('mobility',"dp", "three, three, edos_nw, ntemp, two, nsppol, nrta")
        mu_var = self.reader.read_variable("mobility")
        i, j = abu.s2itup(component)

        for irta in range(self.nrta):
            for itemp, temp in enumerate(self.tmesh):
                mu = mu_var[irta, spin, eh, itemp, :, j, i]
                label = "T = %dK" % temp
                if itemp == 0: label = "%s (%s)" % (label, irta2s(irta))
                if irta == 0 and itemp > 0: label = None
                ax.plot(self.edos_mesh_eV, mu, c=cmap(itemp / self.ntemp), label=label, **style_for_irta(irta))

        self._add_vline_at_bandedge(ax, spin, "cbm" if eh == 0 else "vbm")

        ax.grid(True)
        ax.set_xlabel('Fermi level (eV)')
        ax.set_ylabel(r'%s-mobility $\mu_{%s}(\epsilon_F)$ (cm$^2$/Vs)' % (eh2s(eh), component))
        ax.set_yscale(yscale)
        ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig

    @add_fig_kwargs
    def plot_transport_tensors_mu(self, component="xx", spin=0,
                                  what_list=("sigma", "seebeck", "kappa", "pi"),
                                  colormap="jet", fontsize=8, **kwargs):
        """
        Plot selected Cartesian components of transport tensors as a function
        of the chemical potential mu for all temperatures and RTA types.

        Args:
            component: Cartesian component to plot: "xx", "yy" "xy" ...
            spin: Spin index.
            what_list: Names of the netcdf variables to plot (one subplot per entry).
            colormap: matplotlib colormap.
            fontsize: fontsize for legends and titles

        Return: |matplotlib-Figure|
        """
        i, j = abu.s2itup(component)

        num_plots, ncols, nrows, what_list = x2_grid(what_list)
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()
        # Don't show the last ax if the number of plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)

        for iax, (what, ax) in enumerate(zip(what_list, ax_list)):
            irow, icol = divmod(iax, ncols)
            # nctkarr_t('seebeck', "dp", "three, three, edos_nw, ntemp, nsppol, nrta")
            what_var = self.reader.read_variable(what)

            for irta in range(self.nrta):
                for itemp, temp in enumerate(self.tmesh):
                    ys = what_var[irta, spin, itemp, :, j, i]
                    label = "T = %dK" % temp
                    if itemp == 0: label = "%s (%s)" % (label, irta2s(irta))
                    if irta == 0 and itemp > 0: label = None
                    ax.plot(self.edos_mesh_eV, ys, c=cmap(itemp / self.ntemp), label=label, **style_for_irta(irta))

            ax.grid(True)
            ax.set_ylabel(transptens2latex(what, component))

            ax.legend(loc="best", fontsize=fontsize, shadow=True)
            if irow == nrows - 1:
                ax.set_xlabel(r"$\mu$ (eV)")

            self._add_vline_at_bandedge(ax, spin, "both")

        if "title" not in kwargs:
            fig.suptitle("Transport tensors", fontsize=fontsize)

        return fig

    @add_fig_kwargs
    def plot_ibte_vs_rta_rho(self, component="xx", fontsize=8, ax=None, **kwargs):
        """
        Plot the resistivity computed with SERTA, MRTA and IBTE as a function of T.

        Args:
            component: Cartesian component to plot: "xx", "yy" "xy" ...
            fontsize: fontsize for legends and titles
            ax: |matplotlib-Axes| or None if a new figure should be created.

        Return: |matplotlib-Figure|
        """
        # NOTE(review): reads "resistivity" with irta = 0 and 1, and "ibte_rho",
        # so the file must contain both RTA types and the IBTE results.
        i, j = abu.s2itup(component)
        rta_vals = self.reader.read_value("resistivity")
        serta_t = rta_vals[0, :, j, i]
        mrta_t = rta_vals[1, :, j, i]
        ibte_vals = self.reader.read_value("ibte_rho")
        ibte_t = ibte_vals[:, j, i]

        ax, fig, plt = get_ax_fig_plt(ax=ax)
        ax.grid(True)
        ax.plot(self.tmesh, serta_t, label="serta")
        ax.plot(self.tmesh, mrta_t, label="mrta")
        ax.plot(self.tmesh, ibte_t, label="ibte")
        ax.set_xlabel("Temperature (K)")
        ax.set_ylabel(r"Resistivity ($\mu\Omega\;cm$)")
        ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        Return figures plotting the transport data
        """
        yield self.plot_ibte_vs_rta_rho(show=False)
        #yield self.plot_tau_isoe(show=False)
        #yield self.plot_transport_tensors_mu(show=False)
        #yield self.plot_edos(show=False)
        #yield self.plot_vvtau_dos(show=False)
        #yield self.plot_mobility(show=False, title="Mobility")

    def close(self):
        """Close the file."""
        self.reader.close()

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell('components = ("xx", "yy", "zz", "xy", "xz", "yx")'),
            nbv.new_code_cell("print(ncfile)"),
            nbv.new_code_cell("ncfile.plot_edos();"),
            nbv.new_code_cell("ncfile.plot_vvtau_dos();"),
            nbv.new_code_cell("""
for component in components:
    ncfile.plot_mobility(component=component, spin=0);
"""),
        ])

        return self._write_nb_nbpath(nb, nbpath)
538
539
class RtaReader(ElectronsReader):
    """
    This class reads the results stored in the RTA.nc file.
    It provides helper functions to access the most important quantities.
    """
    def __init__(self, filepath):
        super().__init__(filepath)

        self.nsppol = self.read_dimvalue('nsppol')
        # The file stores k_B T (Ha): divide by k_B in Ha/K to get temperatures in Kelvin.
        self.tmesh = self.read_value("kTmesh") / abu.kb_HaK

    #def read_vvdos_tau(self, itemp, component='xx', spin=0, irta=0):
    #    """
    #    Read the group velocity density of states times lifetime for different temperatures
    #    """
    #    # nctkarr_t('vvtau_dos', "dp", "edos_nw, three, three, ntemp, nsppol, nrta")
    #    i, j = abu.s2itup(component)
    #    emesh = self.read_value("edos_mesh") * abu.Ha_to_eV
    #    vals = self.read_variable("vvtau_dos")
    #    vvtau_dos = vals[irta, spin, itemp, j, i, :] / (2 * abu.Ha_s)
    #    return emesh, vvtau_dos

    #def read_dos(self):
    #    """
    #    Read the density of states (in eV units)
    #    """
    #    # Total DOS, spin up and spin down component.
    #    # nctkarr_t("edos_dos", "dp", "edos_nw, nsppol_plus1")
    #    emesh = self.read_value("edos_mesh") * abu.Ha_to_eV
    #    dos = self.read_value("edos_dos") / abu.Ha_to_eV
    #    idos = self.read_value("edos_idos")
    #    return emesh, dos, idos

    #def read_onsager(self, itemp):
    #    """
    #    Read the Onsager coefficients computed in the transport driver in Abinit
    #    """
    #    # nctkarr_t('L0', "dp", "edos_nw, three, three, ntemp, nsppol, nrta"), &

    #def read_transport(self, itemp):
    #    # nctkarr_t('sigma',   "dp", "edos_nw, three, three, ntemp, nsppol, nrta"), &

    def read_mobility(self, eh, itemp, component, spin, irta=0):
        """
        Read the mobility from the RTA.nc file.
        The mobility is computed separately for electrons and holes.

        Args:
            eh: 0 for electrons, 1 for holes.
            itemp: Index of the temperature.
            component: Cartesian component: "xx", "yy", "xy" ...
            spin: Spin index.
            irta: RTA type index (0: SERTA, 1: MRTA).

        Return: (wvals, mobility) where wvals is the energy mesh as stored in the file (Ha)
            and mobility the values of the selected component on this mesh.
        """
        # nctkarr_t('mobility',"dp", "three, three, edos_nw, ntemp, two, nsppol, nrta")
        i, j = abu.s2itup(component)
        # Return a numpy array rather than the netCDF variable handle,
        # so the result remains usable after the file is closed.
        wvals = self.read_value("edos_mesh")
        mobility = self.read_variable("mobility")[irta, spin, eh, itemp, :, j, i]

        return wvals, mobility
613
614
class RtaRobot(Robot, RobotWithEbands):
    """
    This robot analyzes the results contained in multiple RTA.nc files.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: RtaRobot
    """

    EXT = "RTA"

    @add_fig_kwargs
    def plot_mobility_kconv(self, eh=0, component='xx', itemp=0, spin=0, fontsize=14, ax=None, **kwargs):
        """
        Plot the convergence of the mobility as a function of the number of k-points.

        Args:
            eh: 0 for electrons, 1 for holes.
            component: Cartesian component to plot ('xx', 'xy', ...)
            itemp: temperature index.
            spin: Spin index.
            fontsize: fontsize for legends and titles
            ax: |matplotlib-Axes| or None if a new figure should be created.
            kwargs: Extra options passed to ax.plot (e.g. color, marker, label).

        Returns: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)
        ax.grid(True)
        i, j = abu.s2itup(component)
        irta = 0

        # One [nkx, mobility] entry and one temperature per file.
        res, temps = [], []
        for ncfile in self.abifiles:
            kptrlatt = ncfile.reader.read_value("kptrlatt")
            kptrlattx = kptrlatt[0, 0]
            kptrlatty = kptrlatt[1, 1]
            kptrlattz = kptrlatt[2, 2]
            # nctkarr_t('mobility_mu',"dp", "three, three, two, ntemp, nsppol, nrta")]
            mobility = ncfile.reader.read_variable("mobility_mu")[irta, spin, itemp, eh, j, i]
            res.append([kptrlattx, mobility])
            temps.append(ncfile.tmesh[itemp])

        res.sort(key=lambda t: t[0])
        res = np.array(res)

        size = 14
        ylabel = r"%s mobility (cm$^2$/(V$\cdot$s))" % {0: "Electron", 1: "Hole"}[eh]
        ax.set_ylabel(ylabel, size=size)

        # The x-label uses the aspect ratios of the k-mesh of the last file read.
        from fractions import Fraction
        ratio1 = Fraction(int(kptrlatty), int(kptrlattx))
        ratio2 = Fraction(int(kptrlattz), int(kptrlattx))
        text1 = '' if ratio1.numerator == ratio1.denominator else \
                r'$\frac{{{0}}}{{{1}}}$'.format(ratio1.numerator, ratio1.denominator)
        text2 = '' if ratio2.numerator == ratio2.denominator else \
                r'$\frac{{{0}}}{{{1}}}$'.format(ratio2.numerator, ratio2.denominator)

        ax.set_xlabel(r'Homogeneous $N_k \times$ ' + text1 + r'$N_k \times$ ' + text2 + r'$N_k$ $\mathbf{k}$-point grid',
                      size=size)

        ax.plot(res[:,0], res[:,1], **kwargs)
        ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig

    @lazy_property
    def assume_gap(self):
        """True if we are dealing with a semiconductor. More precisely if all(sigma_erange) > 0."""
        return all(abifile.assume_gap for abifile in self.abifiles)

    @lazy_property
    def all_have_ibte(self):
        """True if all files contain IBTE results."""
        return all(abifile.has_ibte for abifile in self.abifiles)

    def get_same_tmesh(self):
        """
        Check whether all files have the same T-mesh. Return common tmesh else raise RuntimeError.
        """
        ref_tmesh = self.abifiles[0].tmesh
        for abifile in self.abifiles:
            # array_equal also handles meshes with a different number of temperatures.
            if not np.array_equal(ref_tmesh, abifile.tmesh):
                raise RuntimeError("Found different T-mesh in RTA files.")

        return ref_tmesh

    @add_fig_kwargs
    def plot_ibte_vs_rta_rho(self, component="xx", fontsize=8, **kwargs):
        """
        Plot the resistivity computed with SERTA, MRTA and IBTE (one subplot per file).

        Args:
            component: Cartesian component to plot: "xx", "yy" "xy" ...
            fontsize: fontsize for legends and titles

        Returns: |matplotlib-Figure|
        """
        nrows = 1
        ncols = len(self)
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        for abifile, ax in zip(self.abifiles, ax_list):
            abifile.plot_ibte_vs_rta_rho(component=component, fontsize=fontsize, ax=ax, show=False)

        return fig

    @add_fig_kwargs
    def plot_ibte_mrta_serta_conv(self, what="resistivity", fontsize=8, **kwargs):
        """
        Compare the xx component of the resistivity of all files as a function of T.
        One subplot per method (SERTA, MRTA, IBTE), one curve per file.

        Args:
            what: Currently unused: the "resistivity" and "ibte_rho" variables are always read.
            fontsize: fontsize for legends and titles

        Returns: |matplotlib-Figure|
        """
        nrows = 1
        ncols = 3  # SERTA, MRTA, IBTE
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        i = j = 0
        from collections import defaultdict
        data = defaultdict(list)
        for abifile in self.abifiles:
            rta_vals = abifile.reader.read_value("resistivity")
            data["serta"].append(rta_vals[0, :, j, i])
            data["mrta"].append(rta_vals[1, :, j, i])
            ibte_vals = abifile.reader.read_value("ibte_rho")
            data["ibte"].append(ibte_vals[:, j, i])

        tmesh = self.get_same_tmesh()
        keys = ["serta", "mrta", "ibte"]
        for ix, (key, ax) in enumerate(zip(keys, ax_list)):
            ax.grid(True)
            ax.set_title(key.upper(), fontsize=fontsize)
            for ifile, ys in enumerate(data[key]):
                ax.plot(tmesh, ys, marker="o", label=self.labels[ifile])
            ax.set_xlabel("Temperature (K)")
            if ix == 0:
                ax.set_ylabel(r"Resistivity ($\mu\Omega\;cm$)")
                ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Used in abiview.py to get a quick look at the results.
        """
        yield self.plot_ibte_mrta_serta_conv(show=False)
        yield self.plot_ibte_vs_rta_rho(show=False)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        args = [(l, f.filepath) for l, f in self.items()]
        nb.cells.extend([
            nbv.new_code_cell("robot = abilab.RtaRobot(*%s)\nrobot.trim_paths()\nrobot" % str(args)),
        ])

        return self._write_nb_nbpath(nb, nbpath)
804
805
if __name__ == "__main__":
    import sys

    # Build a robot from the RTA.nc files given on the command line and print a summary.
    rta_paths = sys.argv[1:]
    rta_robot = RtaRobot.from_files(rta_paths)
    print(rta_robot)

    # k-point convergence of the mobility with the default arguments of plot_mobility_kconv.
    kconv_label = r'$N_{{q_{{x,y,z}}}}$ = $N_{{k_{{x,y,z}}}}$'
    rta_robot.plot_mobility_kconv(ax=None, color='k', label=kconv_label)

    # Example of a comparison between two series of calculations (kept for reference):
    #fileslist = ['conv_fine/k27x27x27/q27x27x27/Sio_DS1_TRANSPORT.nc',
    #             'conv_fine/k30x30x30/q30x30x30/Sio_DS1_TRANSPORT.nc',
    #             'conv_fine/k144x144x144/q144x144x144/Sio_DS1_TRANSPORT.nc',]
    #plot_mobility_kconv(ax, fileslist, color='k', marker='o', label=r'$N_{{q_{{x,y,z}}}}$ = $N_{{k_{{x,y,z}}}}$')
    #fileslist = ['conv_fine/k27x27x27/q54x54x54/Sio_DS1_TRANSPORT.nc',
    #             'conv_fine/k66x66x66/q132x132x132/Sio_DS1_TRANSPORT.nc',
    #             'conv_fine/k72x72x72/q144x144x144/Sio_DS1_TRANSPORT.nc']
    #plot_mobility_kconv(ax, fileslist, color='r', marker='x', label=r'$N_{{q_{{x,y,z}}}}$ = $2 N_{{k_{{x,y,z}}}}$')
    #plt.legend(loc='best',fontsize=14)
    #plt.show()
832