1"""
2Interface to the GKQ.nc file storing the e-ph matrix elements
3in the atomic representation (idir, ipert) for a single q-point.
4This file is produced by the eph code with eph_task -4.
5To analyze the e-ph scattering potentials, use v1qavg and eph_task 15 or -15
6"""
7import numpy as np
8import abipy.core.abinit_units as abu
9
10from collections import OrderedDict
11from monty.string import marquee
12from monty.functools import lazy_property
13from abipy.core.kpoints import Kpoint
14from abipy.core.mixins import AbinitNcFile, Has_Header, Has_Structure, Has_ElectronBands, NotebookWriter
15from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt
16from abipy.tools import duck
17from abipy.abio.robots import Robot
18from abipy.electrons.ebands import ElectronsReader, RobotWithEbands
19from abipy.eph.common import glr_frohlich, EPH_WTOL
20
21
class GkqFile(AbinitNcFile, Has_Header, Has_Structure, Has_ElectronBands, NotebookWriter):
    """
    Interface to a GKQ.nc file with the e-ph matrix elements computed for a single q-point.

    The matrix elements are stored on disk in the atomic representation (idir, ipert)
    and can be converted to the phonon representation with ``read_all_gkq``.
    """

    @classmethod
    def from_file(cls, filepath):
        """Initialize the object from a netcdf_ file."""
        return cls(filepath)

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = GkqReader(filepath)

    def __str__(self):
        """String representation."""
        return self.to_string()

    def to_string(self, verbose=0):
        """
        String representation with file info, structure, electronic bands, q-point,
        macroscopic dielectric tensor and Born effective charges.

        Args:
            verbose: Verbosity level.
        """
        lines = []; app = lines.append

        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.structure.to_string(verbose=verbose, title="Structure"))
        app("")
        app(self.ebands.to_string(with_structure=False, verbose=verbose, title="Electronic Bands"))
        app("qpoint: %s" % str(self.qpoint))
        app("Macroscopic dielectric tensor in Cartesian coordinates")
        app(str(self.epsinf_cart))
        app("")
        app("Born effective charges in Cartesian coordinates:")
        for i, (site, bec) in enumerate(zip(self.structure, self.becs_cart)):
            app("[%d]: %s" % (i, repr(site)))
            app(str(bec))
            app("")

        app(r"Fulfillment of charge neutrality, F_{ij} = \sum_{atom} Z^*_{ij,atom}")
        f = np.sum(self.becs_cart, axis=0)
        app(str(f) + "\n")

        return "\n".join(lines)

    def close(self):
        """Close the underlying netcdf reader."""
        self.reader.close()

    @lazy_property
    def ebands(self):
        """|ElectronBands| object."""
        return self.reader.read_ebands()

    @lazy_property
    def structure(self):
        """|Structure| object."""
        return self.ebands.structure

    @lazy_property
    def uses_interpolated_dvdb(self):
        """True if the matrix elements have been computed with an interpolated potential."""
        return int(self.reader.read_value("interpolated")) == 1

    @lazy_property
    def params(self):
        """Dict with parameters that might be subject to convergence studies."""
        od = self.get_ebands_params()
        return od

    @lazy_property
    def qpoint(self):
        """Q-point object."""
        return Kpoint(self.reader.read_value('qpoint'), self.structure.reciprocal_lattice)

    @lazy_property
    def phfreqs_ha(self):
        """(3 * natom) array with phonon frequencies in Ha."""
        return self.reader.read_value("phfreqs")

    @lazy_property
    def phdispl_cart_bohr(self):
        """(natom3_nu, natom3) complex array with phonon displacement in cartesian coordinates in Bohr."""
        return self.reader.read_value("phdispl_cart", cmode="c")

    @lazy_property
    def phdispl_red(self):
        """(natom3_nu, natom3) complex array with phonon displacement in reduced coordinates."""
        return self.reader.read_value("phdispl_red", cmode="c")

    @lazy_property
    def becs_cart(self):
        """(natom, 3, 3) array with the Born effective charges in Cartesian coordinates."""
        return self.reader.read_value("becs_cart").transpose(0, 2, 1).copy()

    @lazy_property
    def epsinf_cart(self):
        """(3, 3) array with electronic macroscopic dielectric tensor in Cartesian coordinates."""
        return self.reader.read_value("emacro_cart").T.copy()

    @lazy_property
    def eigens_kq(self):
        """(spin, nkpt, mband) array with eigenvalues on the k+q grid in eV."""
        return self.reader.read_value("eigenvalues_kq") * abu.Ha_eV

    def read_all_gkq(self, mode="phonon"):
        """
        Read all the e-ph matrix elements stored on disk.

        Args:
            mode: "phonon" for eph matrix elements in the phonon representation,
                  "atom" for perturbation along (idir, iatom).

        Return: (nsppol, nkpt, 3*natom, mband, mband) complex array.
            In "phonon" mode, the entries of modes with frequency <= EPH_WTOL are set to zero.

        Raises:
            ValueError: if mode is neither "atom" nor "phonon".
        """
        if mode not in ("atom", "phonon"):
            raise ValueError("Invalid mode: %s" % mode)

        # Read e-ph matrix element in the atomic representation (idir, ipert)
        # Fortran array on disk has shape:
        # nctkarr_t('gkq', "dp", &
        # 'complex, max_number_of_states, max_number_of_states, number_of_phonon_modes, number_of_kpoints, number_of_spins')
        gkq_atm = self.reader.read_value("gkq", cmode="c")
        if mode == "atom": return gkq_atm

        # Convert from atomic to phonon representation:
        #   g_nu = sum_m phdispl_red[nu, m] * g_atm[m] / sqrt(2 w_nu)
        nband = gkq_atm.shape[-1]
        assert nband == gkq_atm.shape[-2] and nband == self.ebands.nband
        natom3 = len(self.structure) * 3
        phfreqs_ha = np.asarray(self.phfreqs_ha)

        # Modes with w_nu <= EPH_WTOL (e.g. acoustic modes at Gamma or unstable modes) get a zero
        # factor. Masking before the sqrt avoids dividing by zero or taking sqrt of negative values.
        inv_sqrt2w = np.zeros(natom3)
        stable = phfreqs_ha > EPH_WTOL
        inv_sqrt2w[stable] = 1.0 / np.sqrt(2.0 * phfreqs_ha[stable])

        # Contract the atomic-perturbation axis (m) with the displacements for every spin, k and band pair.
        gkq_nu = np.einsum("nm,skmab->sknab", self.phdispl_red, gkq_atm)
        gkq_nu *= inv_sqrt2w[np.newaxis, np.newaxis, :, np.newaxis, np.newaxis]

        return gkq_nu

    #def get_averaged_gkq(self, spin, ik, band_k, band_kq, tol_deg=1e-3):
    #    natom3 = len(self.structure) * 3
    #    e_k = self.ebands.eigens[spin, ik, band_k])
    #    e_kq = self.eigens_kq[spin, ik, band_kq]
    #    e_k, ndeg_k, bids_k = _find_deg(spin, ik, self.ebands.eigens)
    #    e_kq, ndeg_kq, bids_kq = _find_deg(spin, ik, self.eigens_kq)
    #
    #    gkq2_nu = np.zeros(natom3))
    #    ncvar = abifile.reader.read_variable("gkq")
    #    for ib_k in bids_k:
    #
    #       for ib_kq in bids_kq:
    #           gkq_atm = ncvar[spin, ik, :, ib_k, ib_kq]
    #           gkq_atm = gkq_atm[:, 0] + 1j * gkq_atm[:, 1]
    #
    #           # Transform the gkk matrix elements from (atom, red_direction) basis to phonon-mode basis.
    #           gkq_nu = np.zeros(natom3), dtype=np.complex)
    #           for nu in range(natom3):
    #               if self.phfreqs_ha[nu] < eph_wtol: continue
    #               gkq_nu[nu] = np.dot(self.phdispl_red[nu], gkq_atm) / np.sqrt(2.0 * self.phfreqs_ha[nu])
    #        gkq2_nu += np.abs(gkq_nu[nu]) ** 2
    #
    #    return np.sqrt(gkq2_nu)

    @add_fig_kwargs
    def plot(self, mode="phonon", with_glr=True, fontsize=8, colormap="viridis", sharey=True, **kwargs):
        """
        Plot the gkq matrix elements for a given q-point.

        Args:
            mode: "phonon" to plot eph matrix elements in the phonon representation,
                  "atom" for atomic representation.
            with_glr: True to plot the long-range component estimated from Verdi's model.
                Only used if mode == "phonon".
            fontsize: Label and title fontsize.
            colormap: matplotlib colormap
            sharey: True if yaxis should be shared among axes.

        Return: |matplotlib-Figure|
        """
        gkq = np.abs(self.read_all_gkq(mode=mode))
        if mode == "phonon": gkq *= abu.Ha_meV

        # |e_{k+q, b'} - e_{k, b}| for all (b, b'), broadcast over the phonon axis so that
        # ediffs has the same shape as gkq: (nsppol, nkpt, 3*natom, mband_k, mband_kq).
        eigens_k, eigens_kq = self.ebands.eigens, self.eigens_kq
        ed = np.abs(eigens_kq[:, :, np.newaxis, :] - eigens_k[:, :, :, np.newaxis])
        ediffs = np.broadcast_to(ed[:, :, np.newaxis, :, :], gkq.shape)

        # The long-range model is only available in the phonon representation.
        show_glr = with_glr and mode == "phonon"
        if show_glr:
            # Add horizontal bar with matrix elements computed from Verdi's model (only G = 0, \delta_nm in bands).
            dcart_bohr = self.phdispl_cart_bohr
            #dcart_bohr = self.reader.read_value("phdispl_cart_qvers", cmode="c").real
            gkq_lr = glr_frohlich(self.qpoint, self.becs_cart, self.epsinf_cart,
                                  dcart_bohr, self.phfreqs_ha, self.structure)
            abs_gkq_lr = np.abs(gkq_lr) * abu.Ha_meV

        natom = len(self.structure)
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=natom, ncols=3,
                                                sharex=True, sharey=sharey, squeeze=False)
        ax_list = ax_list.ravel()
        cmap = plt.get_cmap(colormap)

        # Show only the transitions whose energy difference (eV) is comparable to the max phonon energy.
        ediff_max = 1.2 * self.phfreqs_ha.max() * abu.Ha_eV

        for nu, ax in enumerate(ax_list):
            iat, idir = divmod(nu, 3)
            data, c = gkq[:, :, nu].ravel(), ediffs[:, :, nu].ravel()
            index = c <= ediff_max
            data, c = data[index], c[index]
            sc = ax.scatter(np.arange(len(data)), data, alpha=0.9, s=30, c=c, cmap=cmap)
            plt.colorbar(sc, ax=ax)

            ax.grid(True)
            if iat == natom - 1:
                ax.set_xlabel("Matrix element index")
            if idir == 0:
                ylabel = r"$|g^{atm}_{\bf q}|$" if mode == "atom" else r"$|g_{\bf q}|$ (meV)"
                ax.set_ylabel(ylabel)

            ax.set_title(r"$\nu$: %d, $\omega_{{\bf q}\nu}$ = %.2E (meV)" %
                         (nu, self.phfreqs_ha[nu] * abu.Ha_meV), fontsize=fontsize)

            if show_glr:
                ax.axhline(abs_gkq_lr[nu], color='k', linestyle='dashed', linewidth=2)

        fig.suptitle("qpoint: %s" % repr(self.qpoint), fontsize=fontsize)
        return fig

    @add_fig_kwargs
    def plot_diff_with_other(self, other, mode="phonon", ax_list=None, labels=None, fontsize=8, **kwargs):
        """
        Produce scatter plot and histogram to compare the gkq matrix elements stored in two files.

        Args:
            other: other GkqFile instance.
            mode: "phonon" to plot eph matrix elements in the phonon representation,
                  "atom" for atomic representation.
            ax_list: List with 2 matplotlib axis. None if new ax_list should be created.
            labels: Labels associated to self and other
            fontsize: Label and title fontsize.

        Raises:
            ValueError: if the two files contain different q-points.

        Return: |matplotlib-Figure|
        """
        if self.qpoint != other.qpoint:
            raise ValueError("Found different q-points: %s and %s" % (self.qpoint, other.qpoint))

        if labels is None:
            labels = ["this (interpolated: %s)" % self.uses_interpolated_dvdb,
                      "other (interpolated: %s)" % other.uses_interpolated_dvdb]

        this_gkq = np.abs(self.read_all_gkq(mode=mode))
        other_gkq = np.abs(other.read_all_gkq(mode=mode))
        if mode == "phonon":
            this_gkq *= abu.Ha_meV
            other_gkq *= abu.Ha_meV

        absdiff_gkq = np.abs(this_gkq - other_gkq)

        stats = OrderedDict([
            ("min", absdiff_gkq.min()),
            ("max", absdiff_gkq.max()),
            ("mean", absdiff_gkq.mean()),
            ("std", absdiff_gkq.std()),
        ])

        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=1, ncols=2,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Downsample datasets. Show only points with error > threshold.
        ntot = absdiff_gkq.size
        threshold = stats["mean"] + stats["std"]
        above = absdiff_gkq > threshold
        data = this_gkq[above].ravel()
        nshown = len(data)
        xs = np.arange(nshown)

        ax = ax_list[0]
        ax.scatter(xs, data, alpha=0.9, s=30, label=labels[0],
                   facecolors='none', edgecolors='orange')

        data = other_gkq[above].ravel()
        ax.scatter(xs, data, alpha=0.3, s=10, marker="x", label=labels[1],
                   facecolors="g", edgecolors="none")

        ax.grid(True)
        ax.set_xlabel("Matrix element index")
        ylabel = r"$|g^{atm}_{\bf q}|$" if mode == "atom" else r"$|g_{\bf q}|$ (meV)"
        ax.set_ylabel(ylabel)
        ax.set_title(r"qpt: %s, $\Delta$ > %.1E (%.1f %%)" % (
                     repr(self.qpoint), threshold, 100 * nshown / ntot),
                     fontsize=fontsize)
        ax.legend(loc="best", fontsize=fontsize, shadow=True)

        ax = ax_list[1]
        ax.hist(absdiff_gkq.ravel(), facecolor='g', alpha=0.75)
        ax.grid(True)
        ax.set_xlabel("Absolute Error" if mode == "atom" else "Absolute Error (meV)")
        ax.set_ylabel("Count")

        ax.axvline(stats["mean"], color='k', linestyle='dashed', linewidth=1)
        ax.text(0.7, 0.7,  "\n".join("%s = %.1E" % item for item in stats.items()),
                fontsize=fontsize, horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Used in abiview.py to get a quick look at the results.
        """
        yield self.plot()

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("gkq = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(gkq)"),
            nbv.new_code_cell("gkq.ebands.plot();"),
            nbv.new_code_cell("gkq.epsinf_cart;"),
            nbv.new_code_cell("gkq.becs_cart;"),
            nbv.new_code_cell("""
              #with abilab.abiopen('other_GKQ.nc') as other:
              #     gkq.plot_diff_with_other(other);
            """)
        ])

        return self._write_nb_nbpath(nb, nbpath)
364
365
class GkqReader(ElectronsReader):
    """
    Reader for the netcdf GKQ file written by ABINIT.

    The class adds no methods of its own. Every accessor used on it (e.g. ``read_value``,
    ``read_variable``, ``read_ebands``) is inherited from the parent electrons reader.
    The subclass exists so that GKQ files have a dedicated reader type.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: GkqReader
    """
374
375
class GkqRobot(Robot, RobotWithEbands):
    """
    This robot analyzes the results contained in multiple GKQ.nc files.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: GkqRobot
    """
    EXT = "GKQ"

    @lazy_property
    def kpoints(self):
        """
        List of k-points taken from the first file.

        Raises:
            ValueError: if the files do not contain the same list of k-points.
        """
        # Consistency check: kmesh should be the same in each file.
        ref_kpoints = self.abifiles[0].ebands.kpoints
        for i, abifile in enumerate(self.abifiles):
            if i == 0: continue
            if abifile.kpoints != ref_kpoints:
                for k1, k2 in zip(ref_kpoints, abifile.kpoints):
                    print("k1:", k1, "--- k2:", k2)
                raise ValueError("Found different list of kpoints in %s" % str(abifile.filepath))
        return ref_kpoints

    def _check_qpoints_equal(self):
        """Raises ValueError if different `qpoint` in files."""
        ref_qpoint = self.abifiles[0].qpoint
        for i, abifile in enumerate(self.abifiles):
            if i == 0: continue
            if abifile.qpoint != ref_qpoint:
                raise ValueError("Found different qpoint in %s" % str(abifile.filepath))

    @add_fig_kwargs
    def plot_gkq2_qpath(self, band_kq, band_k, kpoint=0, with_glr=False, qdamp=None, nu_list=None, # spherical_average=False,
                        ax=None, fontsize=8, eph_wtol=EPH_WTOL, **kwargs):
        r"""
        Plot the magnitude of the electron-phonon matrix elements <k+q, band_kq| Delta_{q\nu} V |k, band_k>
        for a given set of (band_kq, band_k, k) as a function of the q-point.

        Args:
            band_kq: Band index of the k+q states (starts at 0)
            band_k: Band index of the k state (starts at 0)
            kpoint: |Kpoint| object or index.
            with_glr: True to plot the long-range component estimated from Verdi's model.
            qdamp: Forwarded to ``glr_frohlich`` as ``qdamp``. Only used if with_glr.
            nu_list: List of phonons modes to be selected (starts at 0). None to select all modes.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: Label and title fontsize.
            eph_wtol: Phonon modes with frequency (Ha) below this value are set to zero.
            kwargs: Optional "pre_label" (prefix of the legend labels) and
                "klabel_size" (size of the q-point tick labels).

        Return: |matplotlib-Figure|
        """
        # Pop the custom options once so that they apply to every curve.
        pre_label = kwargs.pop("pre_label", r"$g_{\bf q}$")
        klabel_size = kwargs.pop("klabel_size", "large")

        if duck.is_intlike(kpoint):
            ik = kpoint
            kpoint = self.kpoints[ik]
        else:
            kpoint = Kpoint.as_kpoint(kpoint, self.abifiles[0].structure.reciprocal_lattice)
            ik = self.kpoints.index(kpoint)

        # Assume abifiles are already ordered according to q-path.
        nqpt = len(self.abifiles)
        xs = list(range(nqpt))
        natom3 = len(self.abifiles[0].structure) * 3
        nsppol = self.abifiles[0].nsppol
        # Zero-initialized so that modes skipped below (w < eph_wtol) stay at zero.
        gkq_snuq = np.zeros((nsppol, natom3, nqpt), dtype=np.complex128)
        if with_glr: gkq_lr = np.zeros((nsppol, natom3, nqpt), dtype=np.complex128)

        # TODO: Should take into account possible degeneracies in k and kq...
        xticks, xlabels = [], []
        for iq, abifile in enumerate(self.abifiles):
            qpoint = abifile.qpoint
            #d3q_fact = one if not spherical_average else np.sqrt(4 * np.pi) * qpoint.norm

            name = qpoint.name if qpoint.name is not None else abifile.structure.findname_in_hsym_stars(qpoint)
            if name is not None:
                xticks.append(iq)
                xlabels.append(name)

            phfreqs_ha, phdispl_red = abifile.phfreqs_ha, abifile.phdispl_red
            ncvar = abifile.reader.read_variable("gkq")
            for spin in range(nsppol):
                gkq_atm = ncvar[spin, ik, :, band_k, band_kq]
                gkq_atm = gkq_atm[:, 0] + 1j * gkq_atm[:, 1]

                # Transform the gkk matrix elements from (atom, red_direction) basis to phonon-mode basis.
                for nu in range(natom3):
                    if phfreqs_ha[nu] < eph_wtol: continue
                    gkq_snuq[spin, nu, iq] = np.dot(phdispl_red[nu], gkq_atm) / np.sqrt(2.0 * phfreqs_ha[nu])

            if with_glr:
                # Compute long range part with (simplified) generalized Frohlich model.
                # glr_frohlich takes no spin argument, so the same values are stored for all spins.
                gkq_lr[:, :, iq] = glr_frohlich(qpoint, abifile.becs_cart, abifile.epsinf_cart,
                                                abifile.phdispl_cart_bohr, phfreqs_ha, abifile.structure, qdamp=qdamp)

        ax, fig, plt = get_ax_fig_plt(ax=ax)

        nu_list = list(range(natom3)) if nu_list is None else list(nu_list)
        for spin in range(nsppol):
            for nu in nu_list:
                ys = np.abs(gkq_snuq[spin, nu]) * abu.Ha_meV
                if nsppol == 1:
                    label = r"%s $\nu$: %s" % (pre_label, nu)
                else:
                    label = r"%s $\nu$: %s, spin: %s" % (pre_label, nu, spin)
                ax.plot(xs, ys, linestyle="--", label=label)
                if with_glr:
                    # Plot model with G = 0 and delta_nn'
                    ys = np.abs(gkq_lr[spin, nu]) * abu.Ha_meV
                    label = r"$g_{\bf q}^{\mathrm{lr0}}$ $\nu$: %s" % nu
                    ax.plot(xs, ys, linestyle="", marker="o", label=label)

        ax.grid(True)
        ax.set_xlabel("Wave Vector")
        ax.set_ylabel(r"$|g_{\bf q}|$ (meV)")
        if xticks:
            ax.set_xticks(xticks, minor=False)
            ax.set_xticklabels(xlabels, fontdict=None, minor=False, size=klabel_size)

        ax.legend(loc="best", fontsize=fontsize, shadow=True)
        title = r"$band_{{\bf k} + {\bf q}}$: %s, $band_{\bf{k}}$: %s, kpoint: %s" % (band_kq, band_k, repr(kpoint))
        ax.set_title(title, fontsize=fontsize)

        return fig

    #@add_fig_kwargs
    #def plot_gkq2_qpath_with_robots(self, other_robots, all_labels, band_kq, band_k, kpoint=0, ax=None, **kwargs):
    #    if not isinstance(other_robots, (list, tuple)):
    #        raise TypeError("other_robots should be a list. Received: %s" % type(other_robots))
    #    if len(all_labels) /= 1 + len(other_robots):
    #        raise ValueError("len(all_labels) should be equal to 1 + len(other_robots)")

    #    ax, fig, plt = get_ax_fig_plt(ax=ax)
    #    #self.plot_gkq2_qpath(self, band_kq, band_k, kpoint=kpoint,
    #    #                with_glr=False, qdamp=None, nu_list=None, # spherical_average=False,
    #    #                ax=ax, fontsize=8, eph_wtol=EPH_WTOL, **kwargs):

    #    return fig

    @add_fig_kwargs
    def plot_gkq2_diff(self, iref=0, **kwargs):
        """
        Wraps gkq.plot_diff_with_other
        Produce scatter and histogram plot to compare the gkq matrix elements stored in all the files
        contained in the robot. Assume all files have the same q-point. Compare the `iref` file with others.
        kwargs are passed to `plot_diff_with_other`.

        Args:
            iref: Index of the reference file.

        Return: |matplotlib-Figure| or None if the robot contains less than two files.
        """
        if len(self) <= 1: return None
        self._check_qpoints_equal()

        # One row (scatter plot + histogram) for each file compared with the reference.
        ncols, nrows = 2, len(self) - 1
        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                               sharex=False, sharey=False, squeeze=False)

        ref_gkq, ref_label = self.abifiles[iref], self.labels[iref]
        cnt = -1
        for ifile, (other_label, other_gkq) in enumerate(zip(self.labels, self.abifiles)):
            if ifile == iref: continue
            cnt += 1
            labels = [ref_label, other_label]
            ref_gkq.plot_diff_with_other(other_gkq, ax_list=ax_mat[cnt], labels=labels, show=False, **kwargs)

        return fig

    def yield_figs(self, **kwargs): # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Used in abiview.py to get a quick look at the results.
        """
        for fig in self.get_ebands_plotter().yield_figs(): yield fig

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to `nbpath`. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        args = [(l, f.filepath) for l, f in self.items()]
        nb.cells.extend([
            #nbv.new_markdown_cell("# This is a markdown cell"),
            nbv.new_code_cell("robot = abilab.GkqRobot(*%s)\nrobot.trim_paths()\nrobot" % str(args)),
            nbv.new_code_cell("# robot.plot_gkq2_diff();"),
            nbv.new_code_cell("# robot.plot_gkq2_qpath(band_kq=0, band_k=0, kpoint=0, with_glr=True, qdamp=None);")
        ])

        # Mixins
        nb.cells.extend(self.get_baserobot_code_cells())
        nb.cells.extend(self.get_ebands_code_cells())

        return self._write_nb_nbpath(nb, nbpath)
563