1# coding: utf-8
2"""Classes to analyse electronic structures."""
3import os
4import copy
5import itertools
6import json
7import warnings
8import tempfile
9import pickle
10import numpy as np
11import pandas as pd
12import pymatgen.core.units as units
13
14from collections import OrderedDict, namedtuple
15from collections.abc import Iterable
16from monty.string import is_string, list_strings, marquee
17from monty.termcolor import cprint
18from monty.json import MontyEncoder
19from monty.collections import AttrDict, dict2namedtuple
20from monty.functools import lazy_property
21from monty.bisect import find_le, find_gt
22from pymatgen.util.serialization import pmg_serialize
23from pymatgen.electronic_structure.core import Spin as PmgSpin
24from abipy.core.func1d import Function1D
25from abipy.core.mixins import Has_Structure, NotebookWriter
26from abipy.core.kpoints import (Kpoint, KpointList, Kpath, IrredZone, KSamplingInfo, KpointsReaderMixin,
27    Ktables, has_timrev_from_kptopt, map_grid2ibz) #, kmesh_from_mpdivs)
28from abipy.core.structure import Structure
29from abipy.iotools import ETSF_Reader
30from abipy.tools import duck
31from abipy.tools.numtools import gaussian
32from abipy.tools.plotting import (set_axlims, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt,
33    get_ax3d_fig_plt, rotate_ticklabels, set_visible, plot_unit_cell, set_ax_xylabels)
34
35
# Public names exported by ``from <this module> import *``.
__all__ = [
    "ElectronBands",
    "ElectronDos",
    "dataframe_from_ebands",
    "ElectronBandsPlotter",
    "ElectronDosPlotter",
]
43
44
class Electron(namedtuple("Electron", "spin kpoint band eig occ kidx")):
    """
    Single-particle state.

    .. Attributes:

        spin: spin index (C convention, i.e >= 0)
        kpoint: |Kpoint| object.
        band: band index. (C convention, i.e >= 0)
        eig: KS eigenvalue.
        occ: Occupation factor.
        kidx: Index of the k-point in the initial array.

    .. note::

        Energies are in eV.
    """
    # NOTE: the docstring above is parsed at runtime by TIPS: keep the
    # "field: description" layout of the Attributes section.

    def __eq__(self, other):
        # States are compared by (spin, kpoint, band) only: eig, occ and kidx are ignored.
        # NOTE: defining __eq__ without __hash__ makes instances unhashable.
        if other is None: return False
        return self.spin == other.spin and self.kpoint == other.kpoint and self.band == other.band

    def __ne__(self, other):
        return not (self == other)

    def __str__(self):
        return "spin: %d, kpt: %s, band: %d, eig: %.3f, occ: %.3f" % (
            self.spin, self.kpoint, self.band, self.eig, self.occ)

    @property
    def skb(self):
        """Tuple with (spin, kpoint, band)."""
        return self.spin, self.kpoint, self.band

    def copy(self):
        """Return a new instance in which every field is copied with copy.copy."""
        return self.__class__(**{f: copy.copy(getattr(self, f)) for f in self._fields})

    @classmethod
    def get_fields(cls, exclude=()):
        """
        Return tuple with the names of the fields, without those listed in ``exclude``.
        Raises ValueError if ``exclude`` contains an unknown field.
        """
        fields = list(cls._fields)
        for e in exclude:
            fields.remove(e)

        return tuple(fields)

    def as_dict(self):
        """Convert self into a dict."""
        return super()._asdict()

    def to_strdict(self, fmt=None):
        """
        Dictionary mapping fields --> strings.

        Complex values are formatted as "re+imj" (just "re" if abs(imag) < 1e-3),
        Python ints with "%d", values accepted by "%.2f" with two decimals and
        everything else with str(). ``fmt`` is accepted but currently unused.
        """
        d = self.as_dict()
        for k, v in d.items():
            if np.iscomplexobj(v):
                if abs(v.imag) < 1.e-3:
                    d[k] = "%.2f" % v.real
                else:
                    d[k] = "%.2f%+.2fj" % (v.real, v.imag)
            elif isinstance(v, int):
                d[k] = "%d" % v
            else:
                try:
                    d[k] = "%.2f" % v
                except TypeError:
                    # e.g. Kpoint objects or tuples that cannot be formatted as floats.
                    d[k] = str(v)
        return d

    @property
    def tips(self):
        """Dictionary with the description of the fields."""
        return self.__class__.TIPS()

    @classmethod
    def TIPS(cls):
        """
        Class method that returns a dictionary with the description of the fields.
        The strings are extracted from the "Attributes" section of the class doc string:
        the text following ``field:`` on the same line plus the following lines
        indented more than the field line.
        If the class has no docstring, the first docstring found in the MRO is used.
        The result is cached per class.

        Raises:
            RuntimeError: if some fields are not documented.
        """
        # Use the class' own __dict__ so that subclasses do not inherit the parent's cache.
        cached = cls.__dict__.get("_TIPS")
        if cached is not None:
            return cached

        # Parse the doc string.
        doc = next((klass.__doc__ for klass in cls.__mro__ if klass.__doc__), "")
        lines = doc.splitlines()

        for i, line in enumerate(lines):
            if line.strip().startswith(".. Attributes"):
                lines = lines[i+1:]
                break

        def num_leadblanks(string):
            """Returns the number of the leading whitespaces in a string"""
            return len(string) - len(string.lstrip())

        tips = {}
        for field in cls._fields:
            prefix = field + ":"
            for i, line in enumerate(lines):
                stripped = line.strip()
                if not stripped.startswith(prefix):
                    continue

                nblanks = num_leadblanks(line)
                desc = []
                # The description starts on the same line as "field:" ...
                head = stripped[len(prefix):].strip()
                if head:
                    desc.append(head)
                # ... and continues on the following lines indented more than the field line.
                for s in lines[i+1:]:
                    if not s.strip() or num_leadblanks(s) <= nblanks:
                        break
                    desc.append(s.lstrip())

                tips[field] = "\n".join(desc)
                break

        diffset = set(cls._fields) - set(tips.keys())
        if diffset:
            raise RuntimeError("The following fields are not documented: %s" % str(diffset))

        # Cache only after validation so that a failure is reported at every call.
        cls._TIPS = tips
        return tips
159
160
class ElectronTransition(object):
    """
    This object describes an electronic transition between two single-particle states.
    """
    def __init__(self, in_state, out_state, all_kinds=None):
        """
        Args:
            in_state, out_state: Initial and final state (:class:`Electron` instances).
            all_kinds: List of tuples. Each tuple gives the index of the k-point of the (initial, final) state.
                Used to plot e.g. all the optical gaps when there are equivalent k-points along the path.
                If None, it defaults to the single pair ``(in_state.kidx, out_state.kidx)``.
        """
        self.in_state, self.out_state = in_state, out_state
        self.all_kinds = [(in_state.kidx, out_state.kidx)] if all_kinds is None else all_kinds

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation. ``verbose`` is accepted for API uniformity but not used."""
        return "\n".join([
            "Energy: %.3f (eV)" % self.energy,
            "Initial state: %s" % str(self.in_state),
            "Final state:   %s" % str(self.out_state),
        ])

    def __eq__(self, other):
        # Transitions are equal when both the initial and the final states compare equal.
        if other is None:
            return False
        same_in = self.in_state == other.in_state
        return same_in and self.out_state == other.out_state

    def __ne__(self, other):
        return not self == other

    @lazy_property
    def energy(self):
        """Transition energy in eV: final eigenvalue minus initial eigenvalue."""
        return self.out_state.eig - self.in_state.eig

    @lazy_property
    def qpoint(self):
        """Momentum transfer: k_final - k_initial."""
        return self.out_state.kpoint - self.in_state.kpoint

    @lazy_property
    def is_direct(self):
        """True if initial and final states have the same k-point (direct transition)."""
        return self.in_state.kpoint == self.out_state.kpoint
213
214
class Smearing(AttrDict):
    """
    Stores data and information about the smearing technique.

    Mandatory keys: ``scheme``, ``occopt`` and ``tsmear_ev``.
    """
    _MANDATORY_KEYS = [
        "scheme",
        "occopt",
        "tsmear_ev",
    ]

    @classmethod
    def from_dict(cls, d):
        """
        Makes Smearing obey the general json interface used in pymatgen for easier serialization.
        Only the mandatory keys are read from ``d``: other entries (e.g. "@module", "@class") are ignored.
        """
        return cls(**{k: d[k] for k in cls._MANDATORY_KEYS})

    @pmg_serialize
    def as_dict(self):
        """
        Makes Smearing obey the general json interface used in pymatgen for easier serialization.

        A new dict is returned: pmg_serialize adds the MSON keys ("@module", "@class") to the
        object returned by this method, so returning ``self`` would inject them into this instance.
        """
        return dict(self)

    def to_json(self):
        """
        Returns a JSON_ string representation of the MSONable object.
        """
        return json.dumps(self.as_dict(), cls=MontyEncoder)

    @classmethod
    def as_smearing(cls, obj):
        """
        Convert obj into a Smearing instance.
        Accepts: Smearing instance, None (if info are not available), Dict-like object.

        Raises:
            TypeError: if obj cannot be converted.
        """
        if isinstance(obj, cls): return obj
        if obj is None:
            return cls(scheme=None, occopt=1, tsmear_ev=0.0)

        # Assume dict-like object.
        try:
            return cls(**obj)
        except Exception as exc:
            raise TypeError("Don't know how to convert %s into Smearing object:\n%s" % (type(obj), str(exc)))

    def __init__(self, *args, **kwargs):
        """Same signature as dict. Raises ValueError if a mandatory key is missing."""
        super().__init__(*args, **kwargs)
        for mkey in self._MANDATORY_KEYS:
            if mkey not in self:
                raise ValueError("Mandatory key %s must be provided" % str(mkey))

    def __str__(self):
        return "smearing scheme: %s (occopt %d), tsmear_eV: %.3f" % (self.scheme, self.occopt, self.tsmear_ev)

    @property
    def has_metallic_scheme(self):
        """True if we are using a metallic scheme for occupancies (occopt in 3..8)."""
        return self.occopt in [3, 4, 5, 6, 7, 8]
274
275
class StatParams(namedtuple("StatParams", "mean stdev min max")):
    """
    Named tuple with statistical parameters (mean, standard deviation, min and max).
    Values are printed in eV with three decimals.
    """
    def __str__(self):
        template = "mean = %.3f, stdev = %.3f, min = %.3f, max = %.3f (eV)"
        # Field order of the tuple matches the order of the placeholders.
        return template % tuple(self)
281
282
class ElectronBandsError(Exception):
    """
    Exceptions raised by ElectronBands
    (also available as ``ElectronBands.Error``).
    """
    pass
285
286
287class ElectronBands(Has_Structure):
288    """
289    Object storing the electron band structure.
290
291    .. attribute:: fermie
292
293            Fermi level in eV. Note that, if the band structure has been computed
294            with a NSCF run, fermie corresponds to the fermi level obtained
295            in the SCF run that produced the density used for the band structure calculation.
296
297    .. rubric:: Inheritance Diagram
298    .. inheritance-diagram:: ElectronBands
299    """
300    Error = ElectronBandsError
301
302    # FIXME
303    # Increase a bit the value of fermie used in bisection routines to solve the problem mentioned below
304    pad_fermie = 1e-3
305
    # One should check whether fermie is recomputed at the end of the SCF cycle.
307    # I have problems in finding homos/lumos in semiconductors (e.g. Si)
308    # because fermie is slightly smaller than the CBM:
309
310    # fermie 5.59845327874
311    # homos [Electron(spin=0, kpoint=[0.000, 0.000, 0.000], band=1, eig=5.5984532787385985, occ=2.0)]
312    # lumos [Electron(spin=0, kpoint=[0.000, 0.000, 0.000], band=2, eig=5.5984532788661543, occ=2.0)]
313    #
314    # There's also another possible problem if the DEN is computed on a grid that does not contain the CBM (e.g. Gamma)
315    # because the CBM obtained with the NSCF band structure run will be likely above the Ef computed previously.
316
317    @classmethod
318    def from_file(cls, filepath):
319        """
320        Initialize an instance of |ElectronBands| from the netCDF file ``filepath``.
321        """
322        if filepath.endswith(".nc"):
323            with ElectronsReader(filepath) as r:
324                new = r.read_ebands()
325        else:
326            raise NotImplementedError("ElectronBands can only be initialized from nc files")
327
328        assert new.__class__ == cls
329        return new
330
331    @classmethod
332    def from_dict(cls, d):
333        """Reconstruct object from the dictionary in MSONable format produced by as_dict."""
334        d = d.copy()
335        kd = d["kpoints"].copy()
336        kd.pop("@module")
337
338        kpoints_cls = KpointList.subclass_from_name(kd.pop("@class"))
339        kpoints = kpoints_cls.from_dict(kd)
340
341        # Needed to support old dictionaries
342        if "nspden" not in d: d["nspden"] = 1
343        if "nspinor" not in d: d["nspinor"] = 1
344        return cls(Structure.from_dict(d["structure"]), kpoints,
345                   d["eigens"], d["fermie"], d["occfacts"], d["nelect"], d["nspinor"], d["nspden"],
346                   nband_sk=d["nband_sk"], smearing=d["smearing"],
347                   linewidths=d.get("linewidths", None)
348                   )
349
350    @pmg_serialize
351    def as_dict(self):
352        """Return dictionary with JSON serialization."""
353        linewidths = None if not self.has_linewidths else self.linewidths.tolist()
354        return dict(
355            structure=self.structure.as_dict(),
356            kpoints=self.kpoints.as_dict(),
357            eigens=self.eigens.tolist(),
358            fermie=float(self.fermie),
359            occfacts=self.occfacts.tolist(),
360            nelect=float(self.nelect),
361            nspinor=self.nspinor,
362            nspden=self.nspden,
363            nband_sk=self.nband_sk.tolist(),
364            smearing=self.smearing.as_dict(),
365            linewidths=linewidths,
366        )
367
368    @classmethod
369    def as_ebands(cls, obj):
370        """
371        Return an instance of |ElectronBands| from a generic object `obj`.
372        Supports:
373
374            - instances of cls
375            - files (string) that can be open with abiopen and that provide an `ebands` attribute.
376            - objects providing an `ebands` attribute
377        """
378        if isinstance(obj, cls):
379            return obj
380
381        elif is_string(obj):
382            # path?
383            if obj.endswith(".pickle"):
384                with open(obj, "rb") as fh:
385                    return cls.as_ebands(pickle.load(fh))
386
387            if obj.endswith("_EBANDS.nc"):
388                return cls.from_file(obj)
389
390            from abipy.abilab import abiopen, abifile_subclass_from_filename
391            try:
392                _ = abifile_subclass_from_filename(obj)
393                use_abiopen = True
394            except ValueError:
395                # This is needed to treat the case in which we are trying to read ElectronBands
396                # from a nc file that is not known to AbiPy.
397                use_abiopen = False
398
399            if use_abiopen:
400                with abiopen(obj) as abifile:
401                    return abifile.ebands
402            else:
403                return cls.from_file(obj)
404
405        elif hasattr(obj, "ebands"):
406            # object with ebands
407            return obj.ebands
408
409        raise TypeError("Don't know how to extract ebands from object `%s`" % type(obj))
410
    @classmethod
    def from_mpid(cls, material_id, api_key=None, endpoint=None,
                  nelect=None, has_timerev=True, nspinor=1, nspden=None, line_mode=True):
        """
        Read bandstructure data corresponding to a materials project ``material_id``.
        and return Abipy ElectronBands object. Return None if bands are not available.

        Args:
            material_id (str): Materials Project material_id (a string, e.g., mp-1234).
            api_key (str): A String API key for accessing the MaterialsProject
                REST interface. Please apply on the Materials Project website for one.
                If this is None, the code will check if there is a `PMG_MAPI_KEY` in
                your .pmgrc.yaml. If so, it will use that environment
                This makes easier for heavy users to simply add
                this environment variable to their setups and MPRester can
                then be called without any arguments.
            endpoint (str): Url of endpoint to access the MaterialsProject REST interface.
                Defaults to the standard Materials Project REST address, but
                can be changed to other urls implementing a similar interface.
            nelect: Number of electrons in the unit cell. If None, it is computed as
                2 * (index of the band hosting the VBM + 1); in this case None is
                returned for metallic bands.
            has_timerev: Passed to from_pymatgen (presumably True if time-reversal
                symmetry can be used -- TODO confirm against from_pymatgen).
            nspinor: Number of spinor components.
            nspden: Number of independent density components (passed to from_pymatgen).
            line_mode (bool): If True, fetch a BandStructureSymmLine object
                (default). If False, return the uniform band structure.

        Return:
            |ElectronBands| object or None if bands are not available (or if nelect is None
            and the band structure is metallic).
        """
        # Get the pymatgen band structure from the Materials Project.
        from abipy.core import restapi
        with restapi.get_mprester(api_key=api_key, endpoint=endpoint) as rest:
            pmgb = rest.get_bandstructure_by_material_id(material_id=material_id, line_mode=line_mode)
            if pmgb is None: return None

            # Structure is set to None so we have to perform another request and patch the object.
            structure = rest.get_structure_by_material_id(material_id, final=True)
            if pmgb.structure is None: pmgb.structure = structure

        if nelect is None:
            # Get nelect from valence band maximum index.
            if pmgb.is_metal():
                cprint("Nelect must be specified if metallic bands.", "red")
                return None
            else:
                d = pmgb.get_vbm()
                iv_up = max(d["band_index"][PmgSpin.up])
                # Two electrons per occupied band (spin-unpolarized counting).
                nelect = (iv_up + 1) * 2
                #print("iv_up", iv_up, "nelect: ", nelect)
                if pmgb.is_spin_polarized:
                    # NOTE(review): if the VBM lies only in the spin-up channel, the spin-down
                    # band_index list may be empty and max() would raise ValueError -- confirm.
                    iv_down = max(d["band_index"][PmgSpin.down])
                    # Require the same VBM band index in both spin channels.
                    assert iv_down == iv_up

        #ksampling = KSamplingInfo.from_kbounds(kbounds)
        return cls.from_pymatgen(pmgb, nelect, weights=None, has_timerev=has_timerev,
                                 ksampling=None, smearing=None, nspinor=nspinor, nspden=nspden)
462
463    def to_json(self):
464        """
465        Returns a JSON string representation of the MSONable object.
466        """
467        return json.dumps(self.as_dict(), cls=MontyEncoder)
468
469    def __init__(self, structure, kpoints, eigens, fermie, occfacts, nelect, nspinor, nspden,
470                 nband_sk=None, smearing=None, linewidths=None):
471        """
472        Args:
473            structure: |Structure| object.
474            kpoints: |KpointList| instance.
475            eigens: Array-like object with the eigenvalues (eV) stored as [s, k, b]
476                where s: spin , k: kpoint, b: band index
477            fermie: Fermi level in eV.
478            occfacts: Occupation factors (same shape as eigens)
479            nelect: Number of valence electrons in the unit cell.
480            nspinor: Number of spinor components
481            nspden: Number of independent density components.
482            nband_sk: Array-like object with the number of bands treated at each [spin,kpoint]
483                      If not given, nband_sk is initialized from eigens.
484            smearing: :class:`Smearing` object storing information on the smearing technique.
485            linewidths: Array-like object with the linewidths (eV) stored as [s, k, b]
486        """
487        self._structure = structure
488
489        # Eigenvalues and occupancies are stored in ndarrays ordered by [spin,kpt,band]
490        self._eigens = np.atleast_3d(eigens)
491        self._occfacts = np.atleast_3d(occfacts)
492        assert self._eigens.shape == self._occfacts.shape
493        self._linewidths = None
494        if linewidths is not None:
495            self._linewidths = np.reshape(linewidths, self._eigens.shape)
496
497        self.nsppol, self.nkpt, self.mband = self.eigens.shape
498        self.nspinor, self.nspden = nspinor, nspden
499
500        if nband_sk is not None:
501            self.nband_sk = np.array(nband_sk)
502        else:
503            self.nband_sk = np.array(self.nsppol * self.nkpt * [self.mband])
504            self.nband_sk.shape = (self.nsppol, self.nkpt)
505
506        self.kpoints = kpoints
507        assert self.nkpt == len(self.kpoints)
508        assert isinstance(self.kpoints, KpointList)
509
510        self.smearing = {} if smearing is None else smearing
511        self.nelect = float(nelect)
512        self.fermie = float(fermie)
513
514    @property
515    def structure(self):
516        """|Structure| object."""
517        return self._structure
518
519    @lazy_property
520    def _auto_klabels(self):
521        # Find the k-point names in the pymatgen database.
522        # We'll use _auto_klabels to label the point in the matplotlib plot
523        # if klabels are not specified by the user.
524
525        _auto_klabels = OrderedDict()
526        # If the first or the last k-point are not recognized in findname_in_hsym_stars
527        # matplotlib won't show the full band structure along the k-path
528        # because the labels are not defined. So we have to make sure that
529        # the labels for the extrema of the path are always defined.
530        _auto_klabels[0] = " "
531
532        for idx, kpoint in enumerate(self.kpoints):
533            name = kpoint.name if kpoint.name is not None else self.structure.findname_in_hsym_stars(kpoint)
534            #if name is not None:
535            if name:
536                _auto_klabels[idx] = name
537                if kpoint.name is None: kpoint.set_name(name)
538
539        last = len(self.kpoints) - 1
540        if last not in _auto_klabels: _auto_klabels[last] = " "
541
542        return _auto_klabels
543
544    def __repr__(self):
545        """String representation (short version)"""
546        return "<%s, nk=%d, %s, id=%s>" % (self.__class__.__name__, self.nkpt, self.structure.formula, id(self))
547
548    def __str__(self):
549        """String representation"""
550        return self.to_string()
551
552    def __add__(self, other):
553        """self + other returns a |ElectronBandsPlotter|."""
554        if not isinstance(other, (ElectronBands, ElectronBandsPlotter)):
555            raise TypeError("Cannot add %s to %s" % (type(self), type(other)))
556
557        if isinstance(other, ElectronBandsPlotter):
558            self_key = repr(self)
559            other.add_ebands(self_key, self)
560            return other
561        else:
562            plotter = ElectronBandsPlotter()
563            self_key = repr(self)
564            plotter.add_ebands(self_key, self)
565            self_key = repr(self)
566            other_key = repr(other)
567            plotter.add_ebands(other_key, other)
568            return plotter
569
570    __radd__ = __add__
571
572    # Handy variables used to loop
573    @property
574    def spins(self):
575        """Spin range"""
576        return range(self.nsppol)
577
578    @property
579    def nband(self):
580        try:
581            return self._nband
582        except AttributeError:
583            assert np.all(self.nband_sk == self.nband_sk[0])
584            self._nband = self.nband_sk[0, 0]
585            return self._nband
586
587    @property
588    def kidxs(self):
589        """Range with the index of the k-points."""
590        return range(self.nkpt)
591
592    @property
593    def eigens(self):
594        """Eigenvalues in eV. |numpy-array| with shape [nspin, nkpt, mband]."""
595        return self._eigens
596
597    @property
598    def linewidths(self):
599        """linewidths in eV. |numpy-array| with shape [nspin, nkpt, mband]."""
600        return self._linewidths
601
602    @linewidths.setter
603    def linewidths(self, linewidths):
604        """Set the linewidths. Accept real array of shape [nspin, nkpt, mband] or None."""
605        if linewidths is not None:
606            linewidths = np.reshape(linewidths, self.shape)
607        self._linewidths = linewidths
608
609    @property
610    def has_linewidths(self):
611        """True if bands with linewidths."""
612        return getattr(self, "_linewidths", None) is not None
613
614    @property
615    def occfacts(self):
616        """Occupation factors. |numpy-array| with shape [nspin, nkpt, mband]."""
617        return self._occfacts
618
619    @property
620    def reciprocal_lattice(self):
621        """|Lattice| with the reciprocal lattice vectors in Angstrom."""
622        return self.structure.reciprocal_lattice
623
624    @property
625    def shape(self):
626        """Shape of the array with the eigenvalues."""
627        return self.nsppol, self.nkpt, self.mband
628
629    @property
630    def has_metallic_scheme(self):
631        """True if we are using a metallic scheme for occupancies."""
632        if self.smearing:
633            return self.smearing.has_metallic_scheme
634        else:
635            cprint("ebands.smearing is not defined, assuming has_metallic_scheme = False", "red")
636            return False
637
638    def set_fermie_to_vbm(self):
639        """
640        Set the Fermi energy to the valence band maximum (VBM).
641        Useful when the initial fermie energy comes from a GS-SCF calculation
642        that may underestimate the Fermi energy because e.g. the IBZ sampling
643        is shifted whereas the true VMB is at Gamma.
644
645        Return: New fermi energy in eV.
646
647        .. warning:
648
649            Assume spin-unpolarized band energies.
650        """
651        iv = int(self.nelect * self.nspinor) // 2 - 1
652        new_fermie = self.eigens[:, :, iv].max()
653        return self.set_fermie(new_fermie)
654
655    def set_fermie_from_edos(self, edos, nelect=None):
656        """
657        Set the Fermi level using the integrated DOS computed in edos.
658
659         Args:
660            edos: |ElectronDos| object.
661            nelect: Number of electrons. If None, the number of electrons in self. is used
662
663        Return: New fermi energy in eV.
664        """
665        if nelect is None:
666            new_fermie = edos.find_mu(self.nelect)
667        else:
668            new_fermie = edos.find_mu(nelect)
669
670        return self.set_fermie(new_fermie)
671
672    def set_fermie(self, new_fermie):
673        """Set the new fermi energy. Return new value"""
674        self.fermie = new_fermie
675        # TODO change occfacts
676        return self.fermie
677
678    def with_points_along_path(self, frac_bounds=None, knames=None, dist_tol=1e-12):
679        """
680        Build new |ElectronBands| object containing the k-points along the
681        input k-path specified by `frac_bounds`. Useful to extract energies along a path
682        from calculation performed in the IBZ.
683
684        Args:
685            frac_bounds: [M, 3] array  with the vertexes of the k-path in reduced coordinates.
686                If None, the k-path is automatically selected from the structure.
687            knames: List of strings with the k-point labels defining the k-path. It has precedence over frac_bounds.
688            dist_tol: A point is considered to be on the path if its distance from the line
689                is less than dist_tol.
690
691        Return:
692            namedtuple with the following attributes::
693
694                ebands: |ElectronBands| object.
695                ik_new2prev: Correspondence between the k-points in the new ebands and the kpoint
696                    of the previous band structure (self).
697        """
698        # Construct the stars of the k-points for all k-points in self.
699        # In principle, the input k-path is arbitrary and not necessarily in the IBZ used for self
700        # so we have to build the k-stars and find the k-points lying along the path and keep
701        # track of the mapping kpt --> star --> kgw
702        # TODO: This part becomes a bottleneck for large nk!
703        stars = [kpoint.compute_star(self.structure.abi_spacegroup.fm_symmops) for kpoint in self.kpoints]
704        cart_coords, back2istar = [], []
705        for istar, star in enumerate(stars):
706            cart_coords.extend([k.cart_coords for k in star])
707            back2istar.extend([istar] * len(star))
708        cart_coords = np.reshape(cart_coords, (-1, 3))
709
710        if knames is not None:
711            assert frac_bounds is None
712            frac_bounds = self.structure.get_kcoords_from_names(knames)
713        else:
714            if frac_bounds is None:
715                frac_bounds = self.structure.calc_kptbounds()
716
717        # Find (star) k-points on the path.
718        cart_bounds = self.structure.reciprocal_lattice.get_cartesian_coords(frac_bounds)
719        from abipy.core.kpoints import find_points_along_path
720        p = find_points_along_path(cart_bounds, cart_coords, dist_tol=dist_tol)
721        if len(p.ikfound) == 0:
722            raise ValueError("Find zero points lying on the input k-path. Try to increase dist_tol")
723
724        new_eigens = np.zeros((self.nsppol, len(p.ikfound), self.mband))
725        new_occfacts = np.zeros_like(new_eigens)
726        new_linewidths = None if self.linewidths is None else np.zeros_like(new_eigens)
727        new_frac_coords = []
728
729        # Correspondence new.kpoints --> self.ebands.kpoints
730        # Useful if client code has to rearrange other arrays ordered according to self.ebands.kpoints.
731        ik_new2prev = []
732        for ik, ik_new in enumerate(p.ikfound):
733            # Stars are ordered as self.kpoints to this is the index we need to access self.eigens.
734            # and trasfer the data from self to new
735            ik_self = back2istar[ik_new]
736            ik_new2prev.append(ik_self)
737            fcs = self.structure.reciprocal_lattice.get_fractional_coords(cart_coords[ik_new])
738            #print("fcs", fcs, "dist", p.dist_list[ik])
739            new_frac_coords.append(fcs)
740            for spin in range(self.nsppol):
741                new_eigens[spin, ik] = self.eigens[spin, ik_self]
742                new_occfacts[spin, ik] = self.occfacts[spin, ik_self]
743                if self.linewidths is not None:
744                    new_linewidths[spin, ik] = self.linewidths[spin, ik_self]
745
746        new_kpoints = Kpath(self.structure.reciprocal_lattice, new_frac_coords, weights=None, names=None)
747
748        new_ebands = self.__class__(self.structure, new_kpoints, new_eigens, self.fermie, new_occfacts,
749                             self.nelect, self.nspinor, self.nspden,
750                             smearing=self.smearing, linewidths=new_linewidths)
751
752        return dict2namedtuple(ebands=new_ebands, ik_new2prev=ik_new2prev)
753
754    #def select_bands(self, bands, kinds=None):
755    #    """Build new ElectronBands object by selecting bands via band_slice (slice object)."""
756    #    bands = np.array(bands)
757    #    kinds = np.array(kinds) if kinds is not None else np.array(range(self.nkpt))
758    #    # This won't work because I need a KpointList object.
759    #    new_kpoints = self.kpoints[kinds]
760    #    new_eigens = self.eigens[:, kinds, bands].copy()
761    #    new_occfacts = self.occupation[:, kinds, bands].copy()
762    #    new_linewidths = None if not self.linewidths else self.linewidths[:, kinds, bands].copy()
763
764    #    return self.__class__(self.structure, new_kpoints, new_eigens, self.fermie, new_occfacts,
765    #                          self.nelect, self.nspinor, self.nspden,
766    #                          smearing=self.smearing, linewidths=new_linewidths)
767
768    @classmethod
769    def empty_with_ibz(cls, ngkpt, structure, fermie, nelect, nsppol, nspinor, nspden, mband,
770                       shiftk=(0, 0, 0), kptopt=1, smearing=None, linewidths=None):
771        from abipy.abio.factories import gs_input
772        from abipy.data.hgh_pseudos import HGH_TABLE
773        gsinp = gs_input(structure, HGH_TABLE, spin_mode="unpolarized")
774        ibz = gsinp.abiget_ibz(ngkpt=ngkpt, shiftk=shiftk, kptopt=kptopt)
775        ksampling = KSamplingInfo.from_mpdivs(ngkpt, shiftk, kptopt)
776
777        kpoints = IrredZone(structure.reciprocal_lattice, ibz.points, weights=ibz.weights,
778                            names=None, ksampling=ksampling)
779
780        new_eigens = np.zeros((nsppol, len(kpoints), mband))
781        new_occfacts = np.zeros_like(new_eigens)
782
783        return cls(structure, kpoints, new_eigens, fermie, new_occfacts,
784                   nelect, nspinor, nspden,
785                   smearing=smearing, linewidths=linewidths)
786
787    def get_dict4pandas(self, with_geo=True, with_spglib=True):
788        """
789        Return a :class:`OrderedDict` with the most important parameters:
790
791            - Chemical formula and number of atoms.
792            - Lattice lengths, angles and volume.
793            - The spacegroup number computed by Abinit (set to None if not available).
794            - The spacegroup number and symbol computed by spglib (set to None not `with_spglib`).
795
796        Useful to construct pandas DataFrames
797
798        Args:
799            with_geo: True if structure info should be added to the dataframe
800            with_spglib: If True, spglib_ is invoked to get the spacegroup symbol and number.
801        """
802        odict = OrderedDict([
803            ("nsppol", self.nsppol), ("nspinor", self.nspinor), ("nspden", self.nspden),
804            ("nkpt", self.nkpt), ("nband", self.nband_sk.min()),
805            ("nelect", self.nelect), ("fermie", self.fermie),
806
807        ])
808
809        # Add info on structure.
810        if with_geo:
811            odict.update(self.structure.get_dict4pandas(with_spglib=with_spglib))
812
813        odict.update(self.smearing)
814
815        bws = self.bandwidths
816        for spin in self.spins:
817            odict["bandwidth_spin%d" % spin] = bws[spin]
818
819        enough_bands = (self.mband > self.nspinor * self.nelect // 2)
820        if enough_bands:
821            for spin in self.spins:
822                odict["fundgap_spin%d" % spin] = self.fundamental_gaps[spin].energy
823            for spin in self.spins:
824                odict["dirgap_spin%d" % spin] = self.direct_gaps[spin].energy
825
826            # Select min gap over spins.
827            min_fgap = self.fundamental_gaps[0]
828            min_dgap = self.direct_gaps[0]
829            if self.nsppol == 2:
830                fgap0, fgap1 = self.fundamental_gaps[0], self.fundamental_gaps[1]
831                min_fgap = fgap0 if fgap0.energy < fgap1.energy else fgap1
832                dgap0, dgap1 = self.direct_gaps[0], self.direct_gaps[1]
833                min_dgap = dgap0 if dgap0.energy < dgap1.energy else dgap1
834
835            # These quantities are not spin-dependent.
836            odict["gap_type"] = "direct" if min_fgap.is_direct else "indirect"
837            odict["fundgap_kstart"] = repr(min_fgap.in_state.kpoint)
838            odict["fundgap_kend"] = repr(min_fgap.out_state.kpoint)
839            odict["dirgap_kstart"] = repr(min_dgap.in_state.kpoint)
840            odict["dirgap_kend"] = repr(min_dgap.out_state.kpoint)
841
842        return odict
843
844    @lazy_property
845    def has_bzmesh(self):
846        """True if the k-point sampling is homogeneous."""
847        return isinstance(self.kpoints, IrredZone)
848
849    @lazy_property
850    def has_bzpath(self):
851        """True if the bands are computed on a k-path."""
852        return isinstance(self.kpoints, Kpath)
853
854    @lazy_property
855    def kptopt(self):
856        """The value of the kptopt input variable."""
857        try:
858            return self.kpoints.ksampling.kptopt
859        except AttributeError:
860            cprint("ebands.kpoints.ksampling.kptopt is not defined, assuming kptopt = 1", "red")
861            return 1
862
863    @lazy_property
864    def has_timrev(self):
865        """True if time-reversal symmetry is used in the BZ sampling."""
866        return has_timrev_from_kptopt(self.kptopt)
867
868    @lazy_property
869    def supports_fermi_surface(self):
870        """
871        True if the kpoints used for the energies can be employed to visualize Fermi surface.
872        Fermi surface viewers require gamma-centered k-mesh.
873        """
874        if self.kpoints.is_mpmesh:
875            mpdivs, shifts = self.kpoints.mpdivs_shifts
876            if shifts is not None and np.all(shifts == 0.0):
877                return True
878        return False
879
880    def kindex(self, kpoint):
881        """
882        The index of the k-point in the internal list of k-points.
883        Accepts: |Kpoint| instance or integer.
884        """
885        if duck.is_intlike(kpoint):
886            return int(kpoint)
887        else:
888            return self.kpoints.index(kpoint)
889
890    def skb_iter(self):
891        """Iterator over (spin, k, band) indices."""
892        for spin in self.spins:
893            for ik in self.kidxs:
894                for band in range(self.nband_sk[spin, ik]):
895                    yield spin, ik, band
896
897    def deepcopy(self):
898        """Deep copy of the ElectronBands object."""
899        return copy.deepcopy(self)
900
901    def degeneracies(self, spin, kpoint, bands_range, tol_ediff=1.e-3):
902        """
903        Returns a list with the indices of the degenerate bands.
904
905        Args:
906            spin: Spin index.
907            kpoint: K-point index or |Kpoint| object
908            bands_range: List of band indices to analyze.
909            tol_ediff: Tolerance on the energy difference (in eV)
910
911        Returns:
912            List of tuples [(e0, bands_e0), (e1, bands_e1, ....]
913            Each tuple stores the degenerate energy and a list with the band indices.
914            The band structure of silicon at Gamma, for example, will produce something like:
915            [(-6.3, [0]), (5.6, [1, 2, 3]), (8.1, [4, 5, 6])]
916        """
917        # Find the index of the k-point
918        k = self.kindex(kpoint)
919
920        # Extract the energies we are interested in.
921        bands_list = list(bands_range)
922        energies = [self.eigens[spin,k,band] for band in bands_list]
923
924        # Group bands according to their degeneracy.
925        bstart, deg_ebands = bands_list[0], []
926        e0, bs = energies[0], [bstart]
927
928        for band, e in enumerate(energies[1:]):
929            band += (bstart + 1)
930            new_deg = abs(e-e0) > tol_ediff
931
932            if new_deg:
933                ebs = (e0, bs)
934                deg_ebands.append(ebs)
935                e0, bs = e, [band]
936            else:
937                bs.append(band)
938
939        deg_ebands.append((e0, bs))
940        return deg_ebands
941
942    def enemin(self, spin=None, band=None):
943        """Compute the minimum of the eigenvalues."""
944        spin_range = self.spins
945        if spin is not None:
946            assert isinstance(spin, int)
947            spin_range = [spin]
948
949        my_kidxs = self.kidxs
950
951        if band is not None:
952            assert isinstance(band, int)
953            my_bands = [band]
954
955        emin = np.inf
956        for spin in spin_range:
957            for k in my_kidxs:
958                if band is None:
959                    my_bands = range(self.nband_sk[spin,k])
960                for band in my_bands:
961                    e = self.eigens[spin,k,band]
962                    emin = min(emin, e)
963        return emin
964
965    def enemax(self, spin=None, band=None):
966        """Compute the maximum of the eigenvalues."""
967        spin_range = self.spins
968        if spin is not None:
969            assert isinstance(spin, int)
970            spin_range = [spin]
971
972        my_kidxs = self.kidxs
973
974        if band is not None:
975            assert isinstance(band, int)
976            my_bands = [band]
977
978        emax = -np.inf
979        for spin in spin_range:
980            for k in my_kidxs:
981                if band is None:
982                    my_bands = range(self.nband_sk[spin,k])
983                for band in my_bands:
984                    e = self.eigens[spin,k,band]
985                    emax = max(emax, e)
986        return emax
987
988    def dispersionless_states(self, erange=None, deltae=0.05, kfact=0.9):
989        """
990        This function detects dispersionless states.
991        A state is dispersionless if there are more that (nkpt * kfact) energies
992        in the energy intervale [e0 - deltae, e0 + deltae]
993
994        Args:
995            erange=Energy range to be analyzed in the form [emin, emax]
996            deltae: Defines the energy interval in eV around the KS eigenvalue.
997            kfact: Can be used to change the criterion used to detect dispersionless states.
998
999        Returns:
1000            List of :class:`Electron` objects. Each item contains information on
1001            the energy, the occupation and the location of the dispersionless band.
1002        """
1003        if erange is None: erange = [-np.inf, np.inf]
1004        kref = 0
1005        dless_states = []
1006        for spin in self.spins:
1007            for band in range(self.nband_sk[spin,kref]):
1008                e0 = self.eigens[spin, kref, band]
1009                if not erange[1] > e0 > erange[0]: continue
1010                hrange = [e0 - deltae, e0 + deltae]
1011                hist, bin_hedges = np.histogram(self.eigens[spin,:,:],
1012                    bins=2, range=hrange, weights=None, density=False)
1013                #print("hist", hist, "hrange", hrange, "bin_hedges", bin_hedges)
1014
1015                if hist.sum() > self.nkpt * kfact:
1016                    state = self._electron_state(spin, kref, band)
1017                    dless_states.append(state)
1018
1019        return dless_states
1020
1021    def get_dataframe(self, e0="fermie"):
1022        """
1023        Return a |pandas-DataFrame| with the following columns:
1024
1025          ['spin', 'kidx', 'band', 'eig', 'occ', 'kpoint']
1026
1027        where:
1028
1029        ==============  ==========================
1030        Column          Meaning
1031        ==============  ==========================
1032        spin            spin index
1033        kidx            k-point index
1034        band            band index
1035        eig             KS eigenvalue in eV.
1036        occ             Occupation of the state.
1037        kpoint          :class:`Kpoint` object
1038        ==============  ==========================
1039
1040        Args:
1041            e0: Option used to define the zero of energy in the band structure plot. Possible values:
1042                - `fermie`: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
1043                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
1044                -  None: Don't shift energies, equivalent to e0=0
1045                The Fermi energy is saved in frame.fermie
1046        """
1047        rows = []
1048        e0 = self.get_e0(e0)
1049        for spin in self.spins:
1050            for k, kpoint in enumerate(self.kpoints):
1051                for band in range(self.nband_sk[spin,k]):
1052                    eig = self.eigens[spin,k,band] - e0
1053                    rows.append(OrderedDict([
1054                               ("spin", spin),
1055                               ("kidx", k),
1056                               ("band", band),
1057                               ("eig", eig),
1058                               ("occ", self.occfacts[spin, k, band]),
1059                               ("kpoint", self.kpoints[k]),
1060                            ]))
1061
1062        frame = pd.DataFrame(rows, columns=list(rows[0].keys()))
1063        frame.fermie = e0
1064        return frame
1065
1066    @add_fig_kwargs
1067    def boxplot(self, ax=None, e0="fermie", brange=None, swarm=False, **kwargs):
1068        """
1069        Use seaborn_ to draw a box plot to show distributions of eigenvalues with respect to the band index.
1070
1071        Args:
1072            ax: |matplotlib-Axes| or None if a new figure should be created.
1073            e0: Option used to define the zero of energy in the band structure plot. Possible values:
1074                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
1075                -  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV.
1076                -  None: Don't shift energies, equivalent to ``e0 = 0``.
1077            brange: Only bands such as ``brange[0] <= band_index < brange[1]`` are included in the plot.
1078            swarm: True to show the datapoints on top of the boxes
1079            kwargs: Keyword arguments passed to seaborn boxplot.
1080
1081        Return: |matplotlib-Figure|
1082        """
1083        # Get the dataframe and select bands
1084        frame = self.get_dataframe(e0=e0)
1085        if brange is not None:
1086            frame = frame[(frame["band"] >= brange[0]) & (frame["band"] < brange[1])]
1087
1088        ax, fig, plt = get_ax_fig_plt(ax=ax)
1089        ax.grid(True)
1090
1091        import seaborn as sns
1092        hue = None if self.nsppol == 1 else "spin"
1093        ax = sns.boxplot(x="band", y="eig", data=frame, hue=hue, ax=ax, **kwargs)
1094        if swarm:
1095            sns.swarmplot(x="band", y="eig", data=frame, hue=hue, color=".25", ax=ax)
1096
1097        return fig
1098
1099    @classmethod
1100    def from_pymatgen(cls, pmg_bands, nelect, weights=None, has_timerev=True,
1101                      ksampling=None, smearing=None, nspinor=1, nspden=None):
1102        """
1103        Convert a pymatgen bandstructure object to an Abipy |ElectronBands| object.
1104
1105        Args:
1106            pmg_bands: pymatgen bandstructure object.
1107            nelect: Number of electrons in unit cell.
1108            weights: List of K-points weights (normalized to one, same order as pmg_bands.kpoints).
1109                This argument is optional but recommended when ``pmg_bands`` represents an IBZ sampling.
1110                If weights are not provided, Abipy methods requiring integrations in the BZ won't work.
1111            has_timerev: True if time-reversal symmetry can be used.
1112            ksampling: dictionary with parameters passed to :class:`KSamplingInfo` defining the k-points sampling.
1113                If None, hard-coded values are used. This argument is recommended if IBZ sampling.
1114            smearing: dictionary with parameters passed to :class:`Smearing`
1115                If None, default hard-coded values are used.
1116            nspinor: Number of spinor components.
1117            nspden: Number of independent spin-density components.
1118                If None, nspden is automatically computed from nsppol
1119
1120        .. warning::
1121
1122            The Abipy bandstructure contains more information than the pymatgen object so
1123            the conversion is not complete, especially if you rely on the default values.
1124            Please read carefylly the docstring and the code and use the optional arguments to pass
1125            additional data required by AbiPy if you need a complete conversion.
1126        """
1127        from pymatgen.electronic_structure.bandstructure import BandStructure, BandStructureSymmLine
1128
1129        # Cast to abipy structure and call spglib to init AbinitSpaceGroup.
1130        abipy_structure = Structure.as_structure(pmg_bands.structure.copy())
1131        if not abipy_structure.has_abi_spacegroup:
1132            abipy_structure.spgset_abi_spacegroup(has_timerev)
1133
1134        # Get dimensions.
1135        nsppol = 2 if pmg_bands.is_spin_polarized else 1
1136        if nspden is None:
1137            if nspinor == 1: nspden = nsppol
1138            if nspinor == 2: nspden = 4
1139        nkpt = len(pmg_bands.kpoints)
1140
1141        smearing = Smearing.as_smearing(smearing)
1142        ksampling = KSamplingInfo.as_ksampling(ksampling)
1143
1144        # Build numpy array with eigenvalues.
1145        abipy_eigens = np.empty((nsppol, nkpt, pmg_bands.nb_bands))
1146        abipy_eigens[0] = np.array(pmg_bands.bands[PmgSpin.up]).T.copy()
1147        if nsppol == 2:
1148            abipy_eigens[1] = np.array(pmg_bands.bands[PmgSpin.down]).T.copy()
1149
1150        # Compute occupation factors. Note that pmg bands don't have occfact so
1151        # I have to compute them from the eigens assuming T=0)
1152        atol = 1e-4
1153        abipy_occfacts = np.where(abipy_eigens <= pmg_bands.efermi + atol, 1, 0)
1154        if nsppol == 1: abipy_occfacts *= 2
1155
1156        reciprocal_lattice = pmg_bands.structure.lattice.reciprocal_lattice
1157        frac_coords = np.array([k.frac_coords for k in pmg_bands.kpoints])
1158
1159        if isinstance(pmg_bands, BandStructureSymmLine):
1160            abipy_kpoints = Kpath(reciprocal_lattice, frac_coords,
1161                                  weights=weights, names=None, ksampling=ksampling)
1162
1163        elif isinstance(pmg_bands, BandStructure):
1164            abipy_kpoints = IrredZone(reciprocal_lattice, frac_coords,
1165                                      weights=weights, names=None, ksampling=ksampling)
1166
1167        else:
1168            raise TypeError("Don't know how to handle type: %s" % type(pmg_bands))
1169
1170        # Find names of the kpoints.
1171        for kpoint in abipy_kpoints:
1172            name = abipy_structure.findname_in_hsym_stars(kpoint)
1173
1174        return cls(abipy_structure, abipy_kpoints, abipy_eigens, pmg_bands.efermi, abipy_occfacts,
1175                   nelect, nspinor, nspden, smearing=smearing)
1176
1177    def to_pymatgen(self):
1178        """
1179        Return a pymatgen bandstructure object from an Abipy |ElectronBands| object.
1180        """
1181        from pymatgen.electronic_structure.bandstructure import BandStructure, BandStructureSymmLine
1182        assert np.all(self.nband_sk == self.nband_sk[0, 0])
1183
1184        # eigenvals is a dict of energies for spin up and spin down
1185        # {Spin.up:[][], Spin.down:[][]}, the first index of the array
1186        # [][] refers to the band and the second to the index of the
1187        # kpoint. The kpoints are ordered according to the order of the
1188        # kpoints array. If the band structure is not spin polarized, we
1189        # only store one data set under Spin.up
1190        eigenvals = {PmgSpin.up: self.eigens[0,:,:].T.copy().tolist()}
1191        if self.nsppol == 2:
1192            eigenvals[PmgSpin.down] = self.eigens[1,:,:].T.copy().tolist()
1193
1194        if self.kpoints.is_path:
1195            labels_dict = {k.name: k.frac_coords for k in self.kpoints if k.name is not None}
1196            return BandStructureSymmLine(self.kpoints.frac_coords, eigenvals, self.reciprocal_lattice, self.fermie,
1197                                         labels_dict, coords_are_cartesian=False, structure=self.structure, projections=None)
1198        else:
1199            return BandStructure(self.kpoints.frac_coords, eigenvals, self.reciprocal_lattice, self.fermie,
1200                                 labels_dict=None, coords_are_cartesian=False, structure=self.structure, projections=None)
1201
1202    def _electron_state(self, spin, kpoint, band):
1203        """
1204        Build an instance of :class:`Electron` from the spin, kpoint and band index
1205        """
1206        kidx = self.kindex(kpoint)
1207        return Electron(spin=spin,
1208                        kpoint=self.kpoints[kidx],
1209                        band=band,
1210                        eig=self.eigens[spin, kidx, band],
1211                        occ=self.occfacts[spin, kidx, band],
1212                        kidx=kidx,
1213                        #fermie=self.fermie
1214                        )
1215
1216    @property
1217    def lomos(self):
1218        """lomo states for each spin channel as a list of nsppol :class:`Electron`."""
1219        lomos = self.nsppol * [None]
1220        for spin in self.spins:
1221            lomo_kidx = self.eigens[spin,:,0].argmin()
1222            lomos[spin] = self._electron_state(spin, lomo_kidx, 0)
1223
1224        return lomos
1225
1226    def lomo_sk(self, spin, kpoint):
1227        """
1228        Returns the LOMO state for the given spin, kpoint.
1229
1230        Args:
1231            spin: Spin index
1232            kpoint: Index of the kpoint or |Kpoint| object.
1233        """
1234        return self._electron_state(spin, kpoint, 0)
1235
1236    def homo_sk(self, spin, kpoint):
1237        """
1238        Returns the HOMO state for the given spin, kpoint.
1239
1240        Args:
1241            spin: Spin index
1242            kpoint: Index of the kpoint or |Kpoint| object.
1243        """
1244        k = self.kindex(kpoint)
1245        # Find rightmost value less than or equal to fermie.
1246        b = find_le(self.eigens[spin,k,:], self.fermie + self.pad_fermie)
1247        return self._electron_state(spin, k, b)
1248
1249    def lumo_sk(self, spin, kpoint):
1250        """
1251        Returns the LUMO state for the given spin, kpoint.
1252
1253        Args:
1254            spin: Spin index
1255            kpoint: Index of the kpoint or |Kpoint| object.
1256        """
1257        k = self.kindex(kpoint)
1258        # Find leftmost value greater than fermie.
1259        b = find_gt(self.eigens[spin,k,:], self.fermie + self.pad_fermie)
1260        return self._electron_state(spin, k, b)
1261
1262    @property
1263    def homos(self):
1264        """homo states for each spin channel as a list of nsppol :class:`Electron`."""
1265        homos = self.nsppol * [None]
1266
1267        for spin in self.spins:
1268            blist, enes = [], []
1269            for k in self.kidxs:
1270                # Find rightmost value less than or equal to fermie.
1271                b = find_le(self.eigens[spin,k,:], self.fermie + self.pad_fermie)
1272                blist.append(b)
1273                enes.append(self.eigens[spin,k,b])
1274
1275            homo_kidx = np.array(enes).argmax()
1276            homo_band = blist[homo_kidx]
1277
1278            # Build Electron instance.
1279            homos[spin] = self._electron_state(spin, homo_kidx, homo_band)
1280
1281        return homos
1282
1283    @property
1284    def lumos(self):
1285        """
1286        lumo states for each spin channel as a list of nsppol :class:`Electron`.
1287        """
1288        lumos = self.nsppol * [None]
1289
1290        for spin in self.spins:
1291            blist, enes = [], []
1292            for k in self.kidxs:
1293                # Find leftmost value greater than fermie.
1294                b = find_gt(self.eigens[spin, k, :], self.fermie + self.pad_fermie)
1295                blist.append(b)
1296                enes.append(self.eigens[spin, k, b])
1297
1298            #enes = np.array(enes)
1299            #kinds = np.where(enes == enes.min())[0]
1300            #lumo_kidx = np.asscalar(kinds[len(kinds) // 2])
1301            lumo_kidx = np.array(enes).argmin()
1302            lumo_band = blist[lumo_kidx]
1303
1304            # Build Electron instance.
1305            lumos[spin] = self._electron_state(spin, lumo_kidx, lumo_band)
1306
1307        return lumos
1308
1309    #def is_metal(self, spin)
1310    #    """True if this spin channel is metallic."""
1311    #    if not self.has_metallic_scheme: return False
1312    #    for k in self.kidxs:
1313    #        # Find leftmost value greater than x.
1314    #        b = find_gt(self.eigens[spin,k,:], self.fermie)
1315    #        if self.eigens[spin,k,b] < self.fermie + 0.01:
1316    #            return True
1317
1318    #def is_semimetal(self, spin)
1319    #    """True if this spin channel is semi-metal."""
1320    #    fun_gaps = self.fundamental_gaps
1321    #    for spin in self.spins:
1322    #       if abs(fun_gaps.ene) <  TOL_EGAP
1323
1324    @property
1325    def bandwidths(self):
1326        """The bandwidth for each spin channel i.e. the energy difference (homo - lomo)."""
1327        return [self.homos[spin].eig - self.lomos[spin].eig for spin in self.spins]
1328
1329    @property
1330    def fundamental_gaps(self):
1331        """List of :class:`ElectronTransition` with info on the fundamental gaps for each spin."""
1332        return [ElectronTransition(self.homos[spin], self.lumos[spin]) for spin in self.spins]
1333
1334    @property
1335    def direct_gaps(self):
1336        """List of `nsppol` :class:`ElectronTransition` with info on the direct gaps for each spin."""
1337        dirgaps = self.nsppol * [None]
1338        for spin in self.spins:
1339            gaps = []
1340            for k in self.kidxs:
1341                homo_sk = self.homo_sk(spin, k)
1342                lumo_sk = self.lumo_sk(spin, k)
1343                gaps.append(lumo_sk.eig - homo_sk.eig)
1344
1345            # Find the index of the k-point where the direct gap is located.
1346            # If there multiple k-points along the path, prefer the one in the center
1347            # If not possible e.g. direct at G with G-X-L-G path avoid points on the right border of the graph
1348            gaps = np.array(gaps)
1349            kinds = np.where(gaps == gaps.min())[0]
1350            kdir = kinds[0]
1351            all_kinds = list(zip(kinds, kinds))
1352            #kdir = kinds[len(kinds) // 2]
1353            #kdir = np.array(gaps).argmin()
1354            dirgaps[spin] = ElectronTransition(self.homo_sk(spin, kdir), self.lumo_sk(spin, kdir), all_kinds=all_kinds)
1355
1356        return dirgaps
1357
1358    def get_gaps_string(self, with_latex=True):
1359        """
1360        Return string with info about fundamental and direct gap (if not metallic scheme)
1361
1362        Args:
1363            with_latex: True to get latex symbols for the gap names else text.
1364        """
1365        enough_bands = (self.mband > self.nspinor * self.nelect // 2)
1366        dg_name, fg_name = "direct gap", "fundamental gap"
1367        if with_latex:
1368            dg_name, fg_name = "$E^{dir}_{gap}$", "$E^{fund}_{gap}$"
1369
1370        if enough_bands and not self.has_metallic_scheme:
1371            if self.nsppol == 1:
1372                s = "%s: %s = %.2f, %s = %.2f (eV)" % (
1373                    self.structure.latex_formula,
1374                    dg_name, self.direct_gaps[0].energy,
1375                    fg_name, self.fundamental_gaps[0].energy)
1376            else:
1377                dgs = [t.energy for t in self.direct_gaps]
1378                fgs = [t.energy for t in self.fundamental_gaps]
1379                s = "%s: %s = %.2f (%.2f), %s = %.2f (%.2f) (eV)" % (
1380                    self.structure.latex_formula,
1381                    dg_name, dgs[0], dgs[1],
1382                    fg_name, fgs[0], fgs[1])
1383        else:
1384            s = ""
1385
1386        return s
1387
1388    def get_kpoints_and_band_range_for_edges(self):
1389        """
1390        Find the reduced coordinates and the band indice associate to the band edges.
1391        Important: Call set_fermie_to_vbm() to set the Fermi level to the VBM before calling this method.
1392
1393        Return: (k0_list, effmass_bands_f90) (Fortran notation)
1394        """
1395        from collections import defaultdict
1396        k0_list, effmass_bands_f90 = [], []
1397        for spin in self.spins:
1398            d = defaultdict(lambda: [np.inf, -np.inf])
1399            homo, lumo = self.homos[spin], self.lumos[spin]
1400            k = tuple(homo.kpoint.frac_coords)
1401            d[k][0] = min(d[k][0], homo.band + 1) # C --> F index
1402            k = tuple(lumo.kpoint.frac_coords)
1403            d[k][1] = max(d[k][1], lumo.band + 1)
1404
1405            for k in d:
1406                if d[k][0] == np.inf: d[k][0] = d[k][1]
1407                if d[k][1] == -np.inf: d[k][1] = d[k][0]
1408                if d[k][0] == np.inf or d[k][1] == -np.inf:
1409                    raise RuntimeError("Cannot find band extrema, dict:\n%s:" % str(d))
1410
1411            for k, v in d.items():
1412                k0_list.append(k)
1413                effmass_bands_f90.append(v)
1414
1415        k0_list = np.reshape(k0_list, (-1, 3))
1416        effmass_bands_f90 = np.reshape(effmass_bands_f90, (-1, 2))
1417        #print("k0_list:\n", k0_list, "\neffmass_bands_f90:\n", effmass_bands_f90)
1418
1419        return k0_list, effmass_bands_f90
1420
    def to_string(self, title=None, with_structure=True, with_kpoints=False, verbose=0):
        """
        Human-readable string with useful info such as band gaps, position of HOMO, LOMO...

        Args:
            title: Optional title. If not None, it is added at the top inside a marquee.
            with_structure: False if structural info should not be displayed.
            with_kpoints: True if k-point info should be displayed.
            verbose: Verbosity level. If nonzero, the LOMO and the exception raised while
                computing the gaps (if any) are also reported.

        Return: String with the lines joined by newlines.
        """
        lines = []; app = lines.append
        if title is not None: app(marquee(title, mark="="))

        if with_structure:
            app(self.structure.to_string(verbose=verbose, title="Structure"))
            app("")

        app("Number of electrons: %s, Fermi level: %.3f (eV)" % (self.nelect, self.fermie))
        app("nsppol: %d, nkpt: %d, mband: %d, nspinor: %s, nspden: %s" % (
           self.nsppol, self.nkpt, self.mband, self.nspinor, self.nspden))
        app(str(self.smearing))

        def indent(s):
            # Indent every line of s by four spaces.
            return "    " + s.replace("\n", "\n    ")

        # Gaps and band edges are reported only if the occupation scheme is not metallic.
        if not self.has_metallic_scheme:
            enough_bands = (self.mband > self.nspinor * self.nelect // 2)
            for spin in self.spins:
                if self.nsppol == 2: app(">>> For spin %s" % spin)
                if enough_bands:
                    # This can fail so we have to catch the exception.
                    # If the fundamental gap fails, the direct gap line has already been appended.
                    try:
                        app("Direct gap:\n%s" % indent(str(self.direct_gaps[spin])))
                        app("Fundamental gap:\n%s" % indent(str(self.fundamental_gaps[spin])))
                    except Exception as exc:
                        app("WARNING: Cannot compute direct and fundamental gap.")
                        if verbose: app("Exception:\n%s" % str(exc))

                app("Bandwidth: %.3f (eV)" % self.bandwidths[spin])
                if verbose:
                    app("Valence minimum located at:\n%s" % indent(str(self.lomos[spin])))

                app("Valence maximum located at:\n%s" % indent(str(self.homos[spin])))

                try:
                    # Cannot assume enough states for this!
                    app("Conduction minimum located at:\n%s" % indent(str(self.lumos[spin])))
                    app("")
                except Exception:
                    pass

            app("TIP: Call set_fermie_to_vbm() to set the Fermi level to the VBM if this is a non-magnetic semiconductor\n")

        if with_kpoints:
            app(self.kpoints.to_string(verbose=verbose, title="K-points"))

        return "\n".join(lines)
1477
1478    def new_with_irred_kpoints(self, prune_step=None):
1479        """
1480        Return a new |ElectronBands| object in which only the irreducible k-points are kept.
1481        This method is mainly used to prepare the band structure interpolation as the interpolator
1482        will likely fail if the input k-path contains symmetrical k-points.
1483
1484        Args:
1485            prune_step: Optional argument used to select a subset of the irreducible points found.
1486            If ``prune_step`` is None, all irreducible k-points are used.
1487        """
1488        # Get the index of the irreducible kpoints.
1489        from abipy.core.kpoints import find_irred_kpoints_generic
1490        nmt = find_irred_kpoints_generic(self.structure, self.kpoints.frac_coords)
1491        irred_map = nmt.irred_map
1492
1493        if prune_step is not None:
1494            irred_map = irred_map[::prune_step].copy()
1495
1496        # Build new set of k-points
1497        new_kcoords = self.kpoints.frac_coords[irred_map].copy()
1498        new_kpoints = KpointList(self.structure.reciprocal_lattice, new_kcoords,
1499                                 weights=None, names=None, ksampling=self.kpoints.ksampling)
1500
1501        # Extract eigevanlues and occupation factors associated to irred k-points.
1502        new_eigens = self.eigens[:, irred_map, :].copy()
1503        new_occfacts = self.occfacts[:, irred_map, :].copy()
1504
1505        return self.__class__(self.structure, new_kpoints, new_eigens, self.fermie, new_occfacts,
1506                              self.nelect, self.nspinor, self.nspden)
1507
1508    def spacing(self, axis=None):
1509        """
1510        Compute the statistical parameters of the energy spacing, i.e. e[b+1] - e[b]
1511
1512        Returns:
1513            ``namedtuple`` with the statistical parameters in eV
1514        """
1515        ediff = self.eigens[:, :, 1:] - self.eigens[:, :, :self.mband-1]
1516
1517        return StatParams(mean=ediff.mean(axis=axis), stdev=ediff.std(axis=axis),
1518                          min=ediff.min(axis=axis), max=ediff.max(axis=axis))
1519
1520    def statdiff(self, other, axis=None, numpy_op=np.abs):
1521        """
1522        Compare the eigenenergies of two bands and compute the
1523        statistical parameters: mean, standard deviation, min and max
1524        The bands are aligned wrt to their fermi level.
1525
1526        Args:
1527            other: |ElectronBands| object.
1528            axis:  Axis along which the statistical parameters are computed.
1529                The default is to compute the parameters of the flattened array.
1530            numpy_op: Numpy function to apply to the difference of the eigenvalues. The
1531                      default computes ``|self.eigens - other.eigens|``.
1532
1533        Returns:
1534            ``namedtuple`` with the statistical parameters in eV
1535        """
1536        ediff = numpy_op(self.eigens - self.fermie - other.eigens + other.fermie)
1537        return StatParams(mean=ediff.mean(axis=axis), stdev=ediff.std(axis=axis),
1538                          min=ediff.min(axis=axis), max=ediff.max(axis=axis))
1539
1540    def ipw_edos_widget(self): # pragma: no cover
1541        """
1542        Return an ipython widget with controllers to compute the electron DOS.
1543        """
1544        def plot_dos(method, step, width):
1545            edos = self.get_edos(method=method, step=step, width=width)
1546            edos.plot()
1547
1548        import ipywidgets as ipw
1549        return ipw.interact_manual(
1550                plot_dos,
1551                method=["gaussian", "tetra"],
1552                step=ipw.FloatSlider(value=0.1, min=1e-6, max=1, step=0.05, description="Step of linear mesh (eV)"),
1553                width=ipw.FloatSlider(value=0.2, min=1e-6, max=1, step=0.05, description="Gaussian broadening (eV)"),
1554            )
1555
1556    def get_edos(self, method="gaussian", step=0.1, width=0.2):
1557        """
1558        Compute the electronic DOS on a linear mesh.
1559
1560        Args:
1561            method: String defining the method for the computation of the DOS.
1562            step: Energy step (eV) of the linear mesh.
1563            width: Standard deviation (eV) of the gaussian.
1564
1565        Returns: |ElectronDos| object.
1566        """
1567        self.kpoints.check_weights()
1568
1569        # Compute linear mesh.
1570        epad = 3.0 * width
1571        e_min = self.enemin() - epad
1572        e_max = self.enemax() + epad
1573        nw = int(1 + (e_max - e_min) / step)
1574        mesh, step = np.linspace(e_min, e_max, num=nw, endpoint=True, retstep=True)
1575
1576        # TODO: Write cython version.
1577        dos = np.zeros((self.nsppol, nw))
1578        if method == "gaussian":
1579            for spin in self.spins:
1580                for k, kpoint in enumerate(self.kpoints):
1581                    weight = kpoint.weight
1582                    for band in range(self.nband_sk[spin,k]):
1583                        e = self.eigens[spin,k,band]
1584                        dos[spin] += weight * gaussian(mesh, width, center=e)
1585
1586        else:
1587            raise NotImplementedError("Method %s is not supported" % method)
1588
1589        # Use fermie from Abinit if we are not using metallic scheme for occopt.
1590        fermie = None
1591        #if self.smearing["occopt"] == 1:
1592        #    print("using fermie from GSR")
1593        #    fermie = self.fermie
1594        edos = ElectronDos(mesh, dos, self.nelect, fermie=fermie)
1595        #print("ebands.fermie", self.fermie, "edos.fermie", edos.fermie)
1596        return edos
1597
1598    def compare_gauss_edos(self, widths, step=0.1):
1599        """
1600        Compute the electronic DOS with the Gaussian method for different values
1601        of the broadening. Return plotter object.
1602
1603        Args:
1604            widths: List with the tandard deviation (eV) of the gaussian.
1605            step: Energy step (eV) of the linear mesh.
1606
1607        Return: |ElectronDosPlotter|
1608        """
1609        edos_plotter = ElectronDosPlotter()
1610        for width in widths:
1611            edos = self.get_edos(method="gaussian", step=0.1, width=width)
1612            label = r"$\sigma = %s$ (eV)" % width
1613            edos_plotter.add_edos(label, edos)
1614
1615        return edos_plotter
1616
1617    @add_fig_kwargs
1618    def plot_transitions(self, omega_ev, qpt=(0, 0, 0), atol_ev=0.1, atol_kdiff=1e-4,
1619                         ylims=None, ax=None, alpha=0.4, **kwargs):
1620        """
1621        Plot energy bands with arrows signaling possible k --> k + q indipendent-particle transitions
1622        of energy ``omega_ev`` connecting occupied to empty states.
1623
1624        Args:
1625            omega_ev: Transition energy in eV.
1626            qpt: Q-point in reduced coordinates.
1627            atol_ev: Absolute tolerance for energy difference in eV
1628            atol_kdiff: Tolerance used to compare k-points.
1629            ylims: Set the data limits for the y-axis. Accept tuple e.g. `(left, right)`
1630                or scalar e.g. `left`. If left (right) is None, default values are used
1631            alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
1632            ax: |matplotlib-Axes| or None if a new figure should be created.
1633
1634        Returns: |matplotlib-Figure|
1635        """
1636        ax, fig, plt = get_ax_fig_plt(ax=ax)
1637        e0 = self.get_e0("fermie")
1638        self.plot(ax=ax, e0=e0, ylims=ylims, show=False)
1639
1640        # Pre-compute mapping k_index --> (k + q)_index, g0
1641        k2kqg = self.kpoints.get_k2kqg_map(qpt, atol_kdiff=atol_kdiff)
1642
1643        # Add arrows to the plot (different colors for spin up/down)
1644        from matplotlib.patches import FancyArrowPatch
1645        for spin in self.spins:
1646            cachek = {}
1647            arrow_opts = {"color": "k"} if spin == 0 else {"color": "red"}
1648            arrow_opts.update(dict(lw=2, arrowstyle="-|>",))
1649            for ik, (ikq, g0) in k2kqg.items():
1650                dx = ikq - ik
1651                ek = self.eigens[spin, ik]
1652                ekq = self.eigens[spin, ikq]
1653                # Find rightmost value less than or equal to fermie.
1654                nv_k = cachek.get(ik)
1655                if nv_k is None:
1656                    nv_k = find_le(ek, self.fermie + self.pad_fermie)
1657                    cachek[ik] = nv_k
1658
1659                if ik == ikq:
1660                    nv_kq = nv_k
1661                else:
1662                    nv_kq = cachek.get(ikq)
1663                    if nv_kq is None:
1664                        nv_kq = find_le(ekq, self.fermie + self.pad_fermie)
1665                        cachek[ikq] = nv_kq
1666
1667                #print("nv_k:", nv_k, "nc_kq", nv_kq)
1668                for v_k in range(nv_k):
1669                    for c_kq in range(nv_kq + 1, self.nband):
1670                        dy = self.eigens[spin, ikq, c_kq] - self.eigens[spin, ik, v_k]
1671                        if abs(dy - omega_ev) > atol_ev: continue
1672                        y = self.eigens[spin, ik, v_k] - e0
1673                        # http://matthiaseisen.com/matplotlib/shapes/arrow/
1674                        p = FancyArrowPatch((ik, y), (ik + dx, y + dy),
1675                                connectionstyle='arc3', mutation_scale=20,
1676                                alpha=alpha, **arrow_opts)
1677                        ax.add_patch(p)
1678        return fig
1679
1680    def get_ejdos(self, spin, valence, conduction, method="gaussian", step=0.1, width=0.2, mesh=None):
1681        r"""
1682        Compute the join density of states at q == 0.
1683
1684            :math:`\sum_{kbv} f_{vk} (1 - f_{ck}) \delta(\omega - E_{ck} + E_{vk})`
1685
1686        .. warning::
1687
1688            The present implementation assumes an energy gap
1689
1690        Args:
1691            spin: Spin index.
1692            valence: Int or iterable with the valence indices.
1693            conduction: Int or iterable with the conduction indices.
1694            method (str): String defining the integraion method.
1695            step: Energy step (eV) of the linear mesh.
1696            width: Standard deviation (eV) of the gaussian.
1697            mesh: Frequency mesh to use. If None, the mesh is computed automatically from the eigenvalues.
1698
1699        Returns: |Function1D| object.
1700        """
1701        # TODO: Generalize to k+q with
1702        # k2kqg = self.kpoints.get_k2kqg_map(qpt, atol_kdiff=atol_kdiff)
1703        self.kpoints.check_weights()
1704        if not isinstance(valence, Iterable): valence = [valence]
1705        if not isinstance(conduction, Iterable): conduction = [conduction]
1706
1707        if mesh is None:
1708            # Compute the linear mesh.
1709            cmin, cmax = +np.inf, -np.inf
1710            vmin, vmax = +np.inf, -np.inf
1711            for c in conduction:
1712                cmin = min(cmin, self.eigens[spin,:,c].min())
1713                cmax = max(cmax, self.eigens[spin,:,c].max())
1714            for v in valence:
1715                vmin = min(vmin, self.eigens[spin,:,v].min())
1716                vmax = max(vmax, self.eigens[spin,:,v].max())
1717
1718            e_min = cmin - vmax
1719            e_min -= 0.1 * abs(e_min)
1720            e_max = cmax - vmin
1721            e_max += 0.1 * abs(e_max)
1722
1723            nw = int(1 + (e_max - e_min) / step)
1724            mesh, step = np.linspace(e_min, e_max, num=nw, endpoint=True, retstep=True)
1725        else:
1726            nw = len(mesh)
1727
1728        jdos = np.zeros(nw)
1729
1730        # Normalize the occupation factors.
1731        full = 2.0 if self.nsppol == 1 else 1.0
1732
1733        if method == "gaussian":
1734            for k, kpoint in enumerate(self.kpoints):
1735                weight = kpoint.weight
1736                for c in conduction:
1737                    ec = self.eigens[spin, k, c]
1738                    fc = 1.0 - self.occfacts[spin,k,c] / full
1739                    for v in valence:
1740                        ev = self.eigens[spin, k, v]
1741                        fv = self.occfacts[spin, k, v] / full
1742                        fact = weight * fv * fc
1743                        jdos += fact * gaussian(mesh, width, center=ec-ev)
1744
1745        else:
1746            raise NotImplementedError("Method %s is not supported" % str(method))
1747
1748        return Function1D(mesh, jdos)
1749
1750    @add_fig_kwargs
1751    def plot_ejdosvc(self, vrange, crange, method="gaussian", step=0.1, width=0.2, colormap="jet",
1752                     cumulative=True, ax=None, alpha=0.7, fontsize=12, **kwargs):
1753        """
1754        Plot the decomposition of the joint-density of States (JDOS).
1755
1756        .. warning::
1757
1758            The present implementation assumes an energy gap
1759
1760        Args:
1761            vrange: Int or `Iterable` with the indices of the valence bands to consider.
1762            crange: Int or `Iterable` with the indices of the conduction bands to consider.
1763            method: String defining the method.
1764            step: Energy step (eV) of the linear mesh.
1765            width: Standard deviation (eV) of the gaussian.
1766            colormap: Have a look at the colormaps here and decide which one you like:
1767                http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html
1768            cumulative: True for cumulative plots (default).
1769            ax: |matplotlib-Axes| or None if a new figure should be created.
1770            alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
1771            fontsize: fontsize for legends and titles
1772
1773        Returns: |matplotlib-Figure|
1774        """
1775        if not isinstance(crange, Iterable): crange = [crange]
1776        if not isinstance(vrange, Iterable): vrange = [vrange]
1777
1778        ax, fig, plt = get_ax_fig_plt(ax=ax)
1779        ax.grid(True)
1780        ax.set_xlabel('Energy (eV)')
1781        cmap = plt.get_cmap(colormap)
1782        lw = kwargs.pop("lw", 1.0)
1783
1784        for s in self.spins:
1785            spin_sign = +1 if s == 0 else -1
1786
1787            # Get total JDOS for this spin
1788            tot_jdos = spin_sign * self.get_ejdos(s, vrange, crange, method=method, step=step, width=width)
1789
1790            # Decomposition in terms of v --> c transitions.
1791            jdos_vc = OrderedDict()
1792            for v in vrange:
1793                for c in crange:
1794                    jd = self.get_ejdos(s, v, c, method=method, step=step, width=width, mesh=tot_jdos.mesh)
1795                    jdos_vc[(v, c)] = spin_sign * jd
1796
1797            # Plot data for this spin.
1798            if cumulative:
1799                cumulative = np.zeros(len(tot_jdos))
1800                num_plots, i = len(jdos_vc), 0
1801                for (v, c), jdos in jdos_vc.items():
1802                    label = r"$v=%s \rightarrow c=%s, \sigma=%s$" % (v, c, s)
1803                    color = cmap(float(i) / num_plots)
1804                    x, y = jdos.mesh, jdos.values
1805                    ax.plot(x, cumulative + y, lw=lw, label=label, color=color)
1806                    ax.fill_between(x, cumulative, cumulative + y, facecolor=color, alpha=alpha)
1807                    cumulative += jdos.values
1808                    i += 1
1809            else:
1810                num_plots, i = len(jdos_vc), 0
1811                for (v, c), jdos in jdos_vc.items():
1812                    color = cmap(float(i) / num_plots)
1813                    jdos.plot_ax(ax, color=color, lw=lw, label=r"$v=%s \rightarrow c=%s, \sigma=%s$" % (v, c, s))
1814                    i += 1
1815
1816            tot_jdos.plot_ax(ax, color="k", lw=lw, label=r"Total JDOS, $\sigma=%s$" % s)
1817
1818        ax.legend(loc="best", shadow=True, fontsize=fontsize)
1819
1820        return fig
1821
1822    def apply_scissors(self, scissors):
1823        """
1824        Modify the band structure with the scissors operator.
1825
1826        Args:
1827            scissors: An instance of :class:`Scissors`.
1828
1829        Returns:
1830            New instance of |ElectronBands| with modified energies.
1831        """
1832        if self.nsppol == 1 and not isinstance(scissors, Iterable):
1833            scissors = [scissors]
1834        if self.nsppol == 2 and len(scissors) != 2:
1835            raise ValueError("Expecting two scissors operators for spin up and down")
1836
1837        # Create new array with same shape as self.
1838        qp_energies = np.zeros(self.shape)
1839
1840        # Calculate quasi-particle energies with the scissors operator.
1841        for spin in self.spins:
1842            sciss = scissors[spin]
1843            for k in self.kidxs:
1844                for band in range(self.nband_sk[spin,k]):
1845                    e0 = self.eigens[spin,k,band]
1846                    try:
1847                        qp_ene = e0 + sciss.apply(e0)
1848                    except sciss.Error:
1849                        raise
1850
1851                    # Update the energy.
1852                    qp_energies[spin,k,band] = qp_ene
1853
1854        # Apply the scissors to the Fermi level as well.
1855        # NB: This should be ok for semiconductors in which fermie == CBM (abinit convention)
1856        # and there's usually one CBM state whose QP correction is expected to be reproduced
1857        # almost exactly by the polyfit.
1858        # Not sure about metals. Besides occupations are not changed here!
1859        fermie = self.fermie + scissors[0].apply(self.fermie)
1860        #fermie = self.fermie
1861        print("KS fermie", self.fermie, "--> QP fermie", fermie, "Delta(QP-KS)=", fermie - self.fermie)
1862
1863        return self.__class__(
1864            self.structure, self.kpoints, qp_energies, fermie, self.occfacts, self.nelect, self.nspinor, self.nspden,
1865            nband_sk=self.nband_sk, smearing=self.smearing)
1866
1867    @add_fig_kwargs
1868    def plot(self, spin=None, band_range=None, klabels=None, e0="fermie", ax=None, ylims=None,
1869             points=None, with_gaps=False, max_phfreq=None, fontsize=8, **kwargs):
1870        r"""
1871        Plot the electronic band structure.
1872
1873        Args:
1874            spin: Spin index. None to plot both spins.
1875            band_range: Tuple specifying the minimum and maximum band to plot (default: all bands are plotted)
1876            klabels: dictionary whose keys are tuple with the reduced
1877                coordinates of the k-points. The values are the labels. e.g.
1878                ``klabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0):"L"}``.
1879            e0: Option used to define the zero of energy in the band structure plot. Possible values:
1880                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
1881                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
1882                -  None: Don't shift energies, equivalent to e0=0
1883            ax: |matplotlib-Axes| or None if a new figure should be created.
1884            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
1885                   or scalar e.g. ``left``. If left (right) is None, default values are used
1886            points: Marker object with the position and the size of the marker.
1887                Used for plotting purpose e.g. QP energies, energy derivatives...
1888            with_gaps: True to add markers and arrows showing the fundamental and the direct gap.
1889                IMPORTANT: If the gaps are now showed correctly in a non-magnetic semiconductor,
1890                    call `ebands.set_fermie_to_vbm()` to align the Fermi level at the top of the valence
1891                    bands before executing `ebands.plot().
1892                    The Fermi energy stored in the object, indeed, comes from the GS calculation
1893                    that produced the DEN file. If the k-mesh used for the GS and the CBM is e.g. at Gamma,
1894                    the Fermi energy will be underestimated and a manual aligment is needed.
1895            max_phfreq: Max phonon frequency in eV to activate scatterplot showing
1896                possible phonon absorption/emission processes based on energy-conservation alone.
1897                All final states whose energy is within +- max_phfreq of the initial state are included.
1898                By default, the four electronic states defining the fundamental and the direct gaps
1899                are considered as initial state (not available for metals).
1900            fontsize: fontsize for legends and titles
1901
1902        Returns: |matplotlib-Figure|
1903        """
1904        # Select spins
1905        spin_list = self.spins if spin is None else [spin]
1906
1907        # Select the band range.
1908        if band_range is None:
1909            band_list = list(range(self.mband))
1910        else:
1911            # This does not work in py2.7 because range is not a class
1912            #if not isinstance(band_range, range):
1913            #    band_list = list(band_range)
1914            band_list = list(range(band_range[0], band_range[1], 1))
1915
1916        e0 = self.get_e0(e0)
1917        ax, fig, plt = get_ax_fig_plt(ax=ax)
1918
1919        # Decorate the axis (e.g add ticks and labels).
1920        self.decorate_ax(ax, klabels=klabels)
1921        set_axlims(ax, ylims, "y")
1922
1923        # Plot the band energies.
1924        for spin in spin_list:
1925            opts = {"color": "black", "linewidth": 2.0} if spin == 0 else \
1926                   {"color": "red", "linewidth": 2.0}
1927            # This to pass kwargs to plot_ax and avoid both lw and linewidth in opts
1928            if "lw" in kwargs: opts.pop("linewidth")
1929            opts.update(kwargs)
1930
1931            for ib, band in enumerate(band_list):
1932                if ib != 0: opts.pop("label", None)
1933                self.plot_ax(ax, e0, spin=spin, band=band, **opts)
1934
1935        if points is not None:
1936            ax.scatter(points.x, np.array(points.y) - e0, s=np.abs(points.s), marker="o", c="b")
1937
1938        if with_gaps and (self.mband > self.nspinor * self.nelect // 2):
1939            # Show fundamental and direct gaps for each spin.
1940            from matplotlib.patches import FancyArrowPatch
1941            for spin in self.spins:
1942                f_gap = self.fundamental_gaps[spin]
1943                d_gap = self.direct_gaps[spin]
1944                # Need arrows only if fundamental and direct gaps for this spin are different.
1945                need_arrows = f_gap != d_gap
1946
1947                arrow_opts = {"color": "k"} if spin == 0 else {"color": "red"}
1948                arrow_opts.update(lw=2, alpha=0.6, arrowstyle="->", connectionstyle='arc3',
1949                                  mutation_scale=20, zorder=1000)
1950                scatter_opts = {"color": "blue"} if spin == 0 else {"color": "green"}
1951                scatter_opts.update(marker="o", alpha=1.0, s=80, zorder=100, edgecolor='black')
1952
1953                # Fundamental gap.
1954                mgap = -1
1955                for ik1, ik2 in f_gap.all_kinds:
1956                    posA = (ik1, f_gap.in_state.eig - e0)
1957                    posB = (ik2, f_gap.out_state.eig - e0)
1958                    mgap = max(mgap, posA[1], posB[1])
1959                    ax.scatter(posA[0], posA[1], **scatter_opts)
1960                    ax.scatter(posB[0], posB[1], **scatter_opts)
1961                    if need_arrows:
1962                        ax.add_patch(FancyArrowPatch(posA=posA, posB=posB, **arrow_opts))
1963
1964                if d_gap != f_gap:
1965                    # Direct gap.
1966                    for ik1, ik2 in d_gap.all_kinds:
1967                        posA = (ik1, d_gap.in_state.eig - e0)
1968                        posB = (ik2, d_gap.out_state.eig - e0)
1969                        mgap = max(mgap, posA[1], posB[1])
1970                        ax.scatter(posA[0], posA[1], **scatter_opts)
1971                        ax.scatter(posB[0], posB[1], **scatter_opts)
1972                        if need_arrows:
1973                            ax.add_patch(FancyArrowPatch(posA=posA, posB=posB, **arrow_opts))
1974
1975            # Try to set nice limits if not given by user.
1976            if ylims is None:
1977                set_axlims(ax, (-mgap - 5, +mgap + 5), "y")
1978
1979            gaps_string = self.get_gaps_string()
1980            if gaps_string:
1981                ax.set_title(gaps_string, fontsize=fontsize)
1982
1983        if max_phfreq is not None and (self.mband > self.nspinor * self.nelect // 2):
1984            # Add markers showing phonon absorption/emission processes.
1985            for spin in self.spins:
1986                #scatter_opts = {"color": "steelblue"} if spin == 0 else {"color": "teal"}
1987                scatter_opts = dict(alpha=0.4, s=40, zorder=10)
1988                items = (["fundamental_gaps", "direct_gaps"], ["in_state", "out_state"])
1989                items = list(enumerate(itertools.product(*items)))
1990                for i, (gap_name, state_name) in items:
1991                    # Use getattr to extract gaps, equivalent to:
1992                    #   gap = self.fundamental_gaps[spin]
1993                    #   e_start = gap.out_state.eig
1994                    gap = getattr(self, gap_name)[spin]
1995                    e_start = getattr(gap, state_name).eig
1996                    scatter_opts["marker"] = "o"
1997                    scatter_opts["color"] = plt.get_cmap("cool" if spin == 0 else "summer")(i/len(items))
1998
1999                    for band in range(self.mband):
2000                        eks = self.eigens[spin, :, band]
2001                        where = np.where(np.abs(e_start - eks) <= max_phfreq)[0]
2002                        if not np.any(where): continue
2003                        ax.scatter(where, eks[where] - e0, **scatter_opts)
2004
2005        return fig
2006
2007    @add_fig_kwargs
2008    def plot_scatter3d(self, band, spin=0, e0="fermie", colormap="jet", ax=None, **kwargs):
2009        r"""
2010        Use matplotlib ``scatter3D`` to produce a scatter plot of the eigenvalues in 3D.
2011        The color of the points gives the energy of the state wrt to the Fermi level.
2012
2013        Args:
2014            band: Band index
2015            spin: Spin index.
2016            e0: Option used to define the zero of energy in the band structure plot. Possible values:
2017                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
2018                -  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
2019                -  None: Don't shift energies, equivalent to ``e0 = 0``
2020            colormap: Have a look at the colormaps here and decide which one you like:
2021                <http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html>
2022            ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
2023        """
2024        kcart_coords = self.kpoints.get_cart_coords()
2025        c = self.eigens[spin, :, band] - self.get_e0(e0)
2026
2027        ax, fig, plt = get_ax3d_fig_plt(ax)
2028        cmap = plt.get_cmap(colormap)
2029        #ax.scatter3D(xs, ys, zs, s=6, alpha=0.8, marker=',', facecolors=cmap(N), lw=0)
2030        p = ax.scatter3D(kcart_coords[:, 0], kcart_coords[:, 1], zs=kcart_coords[:, 2], zdir='z',
2031                         s=20, c=c, depthshade=True, cmap=cmap)
2032
2033        #self.structure.plot_bz(ax=ax, pmg_path=False, with_labels=False, show=False, linewidth=0)
2034        from pymatgen.electronic_structure.plotter import plot_wigner_seitz
2035        plot_wigner_seitz(self.structure.reciprocal_lattice, ax=ax, linewidth=1)
2036        ax.set_xlabel("$K_x$")
2037        ax.set_ylabel("$K_y$")
2038        ax.set_zlabel("$K_z$")
2039        fig.colorbar(p)
2040
2041        #ax.set_title(structure.composition.formula)
2042        ax.set_axis_off()
2043
2044        return fig
2045
2046    def decorate_ax(self, ax, **kwargs):
2047        """
2048        Add k-labels, title and unit name to axis ax.
2049
2050        Args:
2051            title:
2052            fontsize
2053            klabels:
2054            klabel_size:
2055        """
2056        title = kwargs.pop("title", None)
2057        fontsize = kwargs.pop("fontsize", 12)
2058        if title is not None: ax.set_title(title, fontsize=fontsize)
2059
2060        ax.grid(True)
2061        ax.set_ylabel("Energy (eV)")
2062        ax.set_xlabel("Wave Vector")
2063
2064        # Set ticks and labels.
2065        klabels = kwargs.pop("klabels", None)
2066        ticks, labels = self._make_ticks_and_labels(klabels)
2067        if ticks:
2068            # Don't show label if previous k-point is the same.
2069            for il in range(1, len(labels)):
2070                if labels[il] == labels[il-1]: labels[il] = ""
2071            #print("ticks", ticks, "\nlabels", labels)
2072            ax.set_xticks(ticks, minor=False)
2073            ax.set_xticklabels(labels, fontdict=None, minor=False, size=kwargs.pop("klabel_size", "large"))
2074            #print("ticks", len(ticks), ticks)
2075            ax.set_xlim(ticks[0], ticks[-1])
2076
2077    def get_e0(self, e0):
2078        """
2079        e0: Option used to define the zero of energy in the band structure plot. Possible values:
2080                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
2081                -  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
2082                -  None: Don't shift energies, equivalent to ``e0 = 0``.
2083        """
2084        if e0 is None:
2085            return 0.0
2086        elif is_string(e0):
2087            if e0 == "fermie":
2088                return self.fermie
2089            elif e0 == "None":
2090                return 0.0
2091            else:
2092                raise ValueError("Wrong value for e0: %s" % e0)
2093        else:
2094            # Assume number
2095            return e0
2096
2097    def plot_ax(self, ax, e0, spin=None, band=None, **kwargs):
2098        """
2099        Helper function to plot the energies for (spin, band) on the axis ax.
2100
2101        Args:
2102            ax: |matplotlib-Axes|.
2103            e0: Option used to define the zero of energy in the band structure plot.
2104            spin: Spin index. If None, all spins are plotted.
2105            band: Band index, If None, all bands are plotted.
2106            kwargs: Passed to ax.plot
2107
2108        Return: matplotlib lines
2109        """
2110        spin_range = range(self.nsppol) if spin is None else [spin]
2111        band_range = range(self.mband) if band is None else [band]
2112
2113        label = kwargs.pop("label", None)
2114        # Handle linewidths
2115        with_linewidths = kwargs.pop("with_linewidths", True) and self.has_linewidths
2116        if with_linewidths:
2117            lw_opts = kwargs.pop("lw_opts", dict(alpha=0.6))
2118            lw_fact = lw_opts.pop("fact", 2.0)
2119
2120        xx, lines = np.arange(self.nkpt), []
2121        e0 = self.get_e0(e0)
2122        for spin in spin_range:
2123            for band in band_range:
2124                yy = self.eigens[spin, :, band] - e0
2125
2126                # Set label only at the first iteration
2127                lines.extend(ax.plot(xx, yy, label=label, **kwargs))
2128                label = None
2129
2130                if with_linewidths:
2131                    w = self.linewidths[spin, :, band] * lw_fact / 2
2132                    lw_color = lines[-1].get_color()
2133                    ax.fill_between(xx, yy - w, yy + w, facecolor=lw_color, **lw_opts)
2134                    #, alpha=self.alpha, facecolor=self.l2color[l])
2135
2136        return lines
2137
2138    def _make_ticks_and_labels(self, klabels):
2139        """Return ticks and labels from the mapping qlabels."""
2140        if klabels is not None:
2141            d = OrderedDict()
2142            for kcoord, kname in klabels.items():
2143                # Build Kpoint instance.
2144                ktick = Kpoint(kcoord, self.reciprocal_lattice)
2145                for idx, kpt in enumerate(self.kpoints):
2146                    if ktick == kpt: d[idx] = kname
2147
2148        else:
2149            d = self._auto_klabels
2150
2151        # Return ticks, labels
2152        return list(d.keys()), list(d.values())
2153
2154    @add_fig_kwargs
2155    def plot_with_edos(self, edos, klabels=None, ax_list=None, e0="fermie", points=None,
2156                       with_gaps=False, max_phfreq=None, ylims=None, width_ratios=(2, 1), **kwargs):
2157        r"""
2158        Plot the band structure and the DOS.
2159
2160        Args:
2161            edos: An instance of |ElectronDos|.
2162            klabels: dictionary whose keys are tuple with the reduced coordinates of the k-points.
2163                The values are the labels. e.g. ``klabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
2164            ax_list: The axes for the bandstructure plot and the DOS plot. If ax_list is None, a new figure
2165                is created and the two axes are automatically generated.
2166            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
2167                   or scalar e.g. ``left``. If left (right) is None, default values are used
2168            e0: Option used to define the zero of energy in the band structure plot. Possible values::
2169
2170                * ``fermie``: shift all eigenvalues and the DOS to have zero energy at the Fermi energy.
2171                   Note that, by default, the Fermi energy is taken from the band structure object
2172                   i.e. the Fermi energy computed at the end of the SCF file that produced the density.
2173                   This should be ok in semiconductors. In metals, however, a better value of the Fermi energy
2174                   can be obtained from the DOS provided that the k-sampling for the DOS is much denser than
2175                   the one used to compute the density. See ``edos_fermie``.
2176                * ``edos_fermie``: Use the Fermi energy computed from the DOS to define the zero of energy in both subplots.
2177                *  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
2178                *  None: Don't shift energies, equivalent to ``e0 = 0``
2179
2180            points: Marker object with the position and the size of the marker.
2181                Used for plotting purpose e.g. QP energies, energy derivatives...
2182            with_gaps: True to add markers and arrows showing the fundamental and the direct gap.
2183            max_phfreq: Max phonon frequency in eV to activate scatterplot showing
2184                possible phonon absorption/emission processes based on energy-conservation alone.
2185                All final states whose energy is within +- max_phfreq of the initial state are included.
2186                By default, the four electronic states defining the fundamental and the direct gaps
2187                are considered as initial state (not available for metals).
2188            width_ratios: Defines the ratio between the band structure plot and the dos plot.
2189
2190        Return: |matplotlib-Figure|
2191        """
2192        import matplotlib.pyplot as plt
2193        from matplotlib.gridspec import GridSpec
2194
2195        if ax_list is None:
2196            # Build axes and align bands and DOS.
2197            fig = plt.figure()
2198            gspec = GridSpec(nrows=1, ncols=2, width_ratios=width_ratios, wspace=0.05)
2199            ax0 = plt.subplot(gspec[0])
2200            ax1 = plt.subplot(gspec[1], sharey=ax0)
2201        else:
2202            # Take them from ax_list.
2203            ax0, ax1 = ax_list
2204            fig = plt.gcf()
2205
2206        # Define the zero of energy.
2207        e0 = self.get_e0(e0) if e0 != "edos_fermie" else edos.fermie
2208        #if not kwargs: kwargs = {"color": "black", "linewidth": 2.0}
2209
2210        # Plot the band structure
2211        self.plot(e0=e0, ax=ax0, ylims=ylims, klabels=klabels, points=points,
2212                  with_gaps=with_gaps, max_phfreq=max_phfreq, show=False)
2213
2214        # Plot the DOS
2215        if self.nsppol == 1:
2216            opts = {"color": "black", "linewidth": 2.0}
2217            edos.plot_ax(ax1, e0, exchange_xy=True, **opts)
2218        else:
2219            for spin in self.spins:
2220                opts = {"color": "black", "linewidth": 2.0} if spin == 0 else \
2221                       {"color": "red", "linewidth": 2.0}
2222                edos.plot_ax(ax1, e0, spin=spin, exchange_xy=True, **opts)
2223
2224        ax1.grid(True)
2225        ax1.yaxis.set_ticks_position("right")
2226        ax1.yaxis.set_label_position("right")
2227        set_axlims(ax1, ylims, "y")
2228
2229        return fig
2230
2231    @add_fig_kwargs
2232    def plot_lws_vs_e0(self, ax=None, e0="fermie", function=lambda x: x, exchange_xy=False,
2233                       xlims=None, ylims=None, fontsize=12, **kwargs):
2234        r"""
2235        Plot electronic linewidths vs KS energy.
2236
2237        Args:
2238            ax: |matplotlib-Axes| or None if a new figure should be created.
2239            e0: Option used to define the zero of energy in the band structure plot. Possible values:
2240                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
2241                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
2242                -  None: Don't shift energies, equivalent to e0=0
2243            function: Apply this function to the values before plotting
2244            exchange_xy: True to exchange x-y axis.
2245            xlims, ylims: Set the data limits for the x-axis or the y-axis. Accept tuple e.g. ``(left, right)``
2246                   or scalar e.g. ``left``. If left (right) is None, default values are used
2247            fontsize: fontsize for titles and legend.
2248
2249        Returns: |matplotlib-Figure|
2250        """
2251        if not self.has_linewidths: return None
2252        ax, fig, plt = get_ax_fig_plt(ax=ax)
2253
2254        xlabel = r"$\epsilon_{KS}\;(eV)$"
2255        if e0 is not None:
2256            xlabel = r"$\epsilon_{KS}-\epsilon_F\;(eV)$"
2257
2258        # DSU sort to get lw(e) with sorted energies.
2259        e0mesh, lws = zip(*sorted(zip(self.eigens.flat, self.linewidths.flat), key=lambda t: t[0]))
2260        e0 = self.get_e0(e0)
2261        e0mesh = np.array(e0mesh) - e0
2262
2263        kw_linestyle = kwargs.pop("linestyle", "o")
2264        #kw_lw = kwargs.pop("lw", 1)
2265        #kw_lw = kwargs.pop("markersize", 5)
2266        kw_color = kwargs.pop("color", "red")
2267        kw_label = kwargs.pop("label", None)
2268
2269        xx, yy = e0mesh, tuple([function(lw) for lw in lws])
2270        if exchange_xy: xx, yy = yy, xx
2271        ax.plot(xx, yy, kw_linestyle, color=kw_color, label=kw_label, **kwargs)
2272        #ax.scatter(xx, yy)
2273
2274        ax.grid(True)
2275        ylabel = "Linewidth (eV)"
2276        if exchange_xy: xlabel, ylabel = ylabel, xlabel
2277        ax.set_ylabel(ylabel)
2278        ax.set_xlabel(xlabel)
2279        set_axlims(ax, xlims, "x")
2280        set_axlims(ax, ylims, "y")
2281        if kw_linestyle:
2282            ax.legend(loc="best", fontsize=fontsize, shadow=True)
2283
2284        return fig
2285
2286    def to_xmgrace(self, filepath):
2287        """
2288        Write xmgrace_ file with band structure energies and labels for high-symmetry k-points.
2289
2290        Args:
2291            filepath: String with filename or stream.
2292        """
2293        is_stream = hasattr(filepath, "write")
2294        if is_stream:
2295            f = filepath
2296        else:
2297            f = open(filepath, "wt")
2298
2299        def w(s):
2300            f.write(s)
2301            f.write("\n")
2302
2303        emef = np.array(self.eigens - self.fermie)
2304
2305        import datetime
2306        w("# Grace project file")
2307        w("# Generated by abipy on: %s" % str(datetime.datetime.today()))
2308        w("# Crystalline structure:")
2309        for s in str(self.structure).splitlines():
2310            w("# %s" % s)
2311        w("# mband: %d, nkpt: %d, nsppol: %d, nspinor: %d" % (self.mband, self.nkpt, self.nsppol, self.nspinor))
2312        w("# nelect: %8.2f, %s" % (self.nelect, str(self.smearing)))
2313        w("# Energies are in eV. Zero set to efermi, previously it was at: %s (eV)" % self.fermie)
2314        w("# List of k-points and their index (C notation i.e. count from 0)")
2315        for ik, kpt in enumerate(self.kpoints):
2316            w("# %d %s" % (ik, str(kpt.frac_coords)))
2317        w("@page size 792, 612")
2318        w("@page scroll 5%")
2319        w("@page inout 5%")
2320        w("@link page off")
2321        w("@with g0")
2322        w("@world xmin 0.00")
2323        w('@world xmax %d' % (self.nkpt - 1))
2324        w('@world ymin %s' % emef.min())
2325        w('@world ymax %s' % emef.max())
2326        w('@default linewidth 1.5')
2327        w('@xaxis  tick on')
2328        w('@xaxis  tick major 1')
2329        w('@xaxis  tick major color 1')
2330        w('@xaxis  tick major linestyle 3')
2331        w('@xaxis  tick major grid on')
2332        w('@xaxis  tick spec type both')
2333        w('@xaxis  tick major 0, 0')
2334
2335        kticks, klabels = self._make_ticks_and_labels(klabels=None)
2336        w('@xaxis  tick spec %d' % len(kticks))
2337        for ik, (ktick, klabel) in enumerate(zip(kticks, klabels)):
2338            w('@xaxis  tick major %d, %d' % (ik, ktick))
2339            w('@xaxis  ticklabel %d, "%s"' % (ik, klabel))
2340
2341        w('@xaxis  ticklabel char size 1.500000')
2342        w('@yaxis  tick major 10')
2343        w('@yaxis  label "Band Energy (eV)"')
2344        w('@yaxis  label char size 1.500000')
2345        w('@yaxis  ticklabel char size 1.500000')
2346        ii = -1
2347        for spin in range(self.nsppol):
2348            for band in range(self.mband):
2349                ii += 1
2350                w('@    s%d line color %d' % (ii, spin + 1))
2351
2352        ii = -1
2353        for spin in range(self.nsppol):
2354            for band in range(self.mband):
2355                ii += 1
2356                w('@target G0.S%d' % ii)
2357                w('@type xy')
2358                for ik in range(self.nkpt):
2359                    w('%d %.8E' % (ik, emef[spin, ik, band]))
2360                w('&')
2361
2362        if not is_stream:
2363            f.close()
2364
2365    def to_bxsf(self, filepath):
2366        """
2367        Export the full band structure to ``filepath`` in BXSF format
2368        suitable for the visualization of isosurfaces with xcrysden_ (xcrysden --bxsf FILE).
2369        Require k-points in IBZ and gamma-centered k-mesh.
2370        """
2371        self.get_ebands3d().to_bxsf(filepath)
2372
2373    def get_ebands3d(self):
2374        return ElectronBands3D(self.structure, self.kpoints, self.has_timrev, self.eigens, self.fermie)
2375
2376    def derivatives(self, spin, band, order=1, acc=4):
2377        """
2378        Compute the derivative of the eigenvalues wrt to k.
2379
2380        Args:
2381            spin: Spin index
2382            band: Band index
2383            order:
2384            acc:
2385
2386        Returns:
2387        """
2388        if self.kpoints.is_path:
2389            # Extract the energy branch.
2390            ebranch = self.eigens[spin, :, band]
2391            # Simulate free-electron bands. This will produce all(effective masses == 1)
2392            #ebranch = 0.5 * units.Ha_to_eV * np.array([(k.norm * units.bohr_to_ang)**2 for k in self.kpoints])
2393
2394            # Compute derivatives by finite differences.
2395            ders_onlines = self.kpoints.finite_diff(ebranch, order=order, acc=acc)
2396            return ders_onlines
2397
2398        else:
2399            raise NotImplementedError("Derivatives on homogeneous k-meshes are not supported yet")
2400
2401    def effective_masses(self, spin, band, acc=4):
2402        """
2403        Compute the effective masses for the given ``spin`` and ``band`` index.
2404        Use finite difference with accuracy ``acc``.
2405
2406        Returns:
2407            |numpy-array| of size self.nkpt with effective masses.
2408        """
2409        ders2 = self.derivatives(spin, band, order=2, acc=acc) * (units.eV_to_Ha / units.bohr_to_ang**2)
2410        return 1. / ders2
2411
    def get_effmass_line(self, spin, kpoint, band, acc=4):
        """
        Compute the effective masses along a k-line. Requires band energies on a k-path.

        For each index of ``kpoint`` in the path, the second derivative of the (spin, band) energies
        is computed by finite differences on the first segment of the path containing the k-point.
        If the k-point is the last point of a segment followed by another segment, the derivative
        is also computed on the next segment ("right" effective mass).
        The results are printed to stdout, nothing is returned.

        Args:
            spin: Spin index.
            kpoint: integer, list of fractional coordinates or |Kpoint| object.
            band: Band index.
            acc: accuracy of the finite-difference scheme.

        Raises:
            ValueError: if the k-points do not form a path, if a k-index cannot be found in the
                lines of the path or if the segment is not homogeneous (see ``_eigens_hvers_iline``).
        """
        warnings.warn("This code is still under development. API may change!")
        if not self.kpoints.is_path:
            raise ValueError("get_effmass_line requires k-points along a path. Got:\n %s" % repr(self.kpoints))

        # We have to understand if the k-point is a vertex or not.
        # If it is a vertex, we have to compute the left and right derivative
        # If kpt is inside the line, left and right derivatives are supposed to be equal
        from abipy.tools.derivatives import finite_diff

        for ik in self.kpoints.get_all_kindices(kpoint):
            # Find the first line (segment) whose index range [line[0], line[-1]] contains ik.
            for iline, line in enumerate(self.kpoints.lines):
                if line[-1] >= ik >= line[0]: break
            else:
                raise ValueError("Cannot find k-index `%s` in lines: `%s`" % (ik, self.kpoints.lines))

            # Position of ik in the segment. 0 or len(line) - 1 means that ik is an end point.
            kpos = line.index(ik)
            is_inside = kpos not in (0, len(line) - 1)
            # The right derivative is needed only if ik is the last point of a segment
            # that is followed by another segment.
            do_right = (not is_inside) and kpos != 0 and iline != len(self.kpoints.lines) - 1

            evals_on_line, h_left, vers_left = self._eigens_hvers_iline(spin, band, iline)
            d2 = finite_diff(evals_on_line, h_left, order=2, acc=acc, index=kpos)
            # Convert d2E/dk2 to atomic units and invert to get the effective mass.
            em_left = 1. / (d2.value * (units.eV_to_Ha / units.bohr_to_ang ** 2))
            # By default, the "right" quantities coincide with the "left" ones.
            em_right = em_left
            h_right, vers_right = h_left, vers_left

            if do_right:
                # ik must be the first point of the next segment.
                # NOTE(review): assert statements are removed when Python runs with -O.
                kpos_right = self.kpoints.lines[iline + 1].index(ik)
                assert kpos_right == 0
                evals_on_line, h_right, vers_right = self._eigens_hvers_iline(spin, band, iline + 1)
                d2 = finite_diff(evals_on_line, h_right, order=2, acc=acc, index=kpos_right)
                em_right = 1. / (d2.value * (units.eV_to_Ha / units.bohr_to_ang ** 2))

            # Build and print the report for this k-point.
            lines = []; app = lines.append
            app("For spin: %s, band: %s, k-point: %s, eig: %.3f [eV], accuracy: %s" % (
                spin, band, repr(self.kpoints[ik]), self.eigens[spin, ik, band], acc))
            #app("K-point: %s, eigenvalue: %s (eV)" % (repr(self.kpoint), self.eig))
            #app("h_left: %s, h_right %s" % (self.h_left, self.h_right))
            #app("is_inside: %s, vers_left: %s, vers_right: %s" % (self.is_inside, self.vers_left, self.vers_right))
            if em_left != em_right:
                app("emass_left: %.3f, emass_right: %.3f" % (em_left, em_right))
            else:
                app("emass: %.3f" % em_left)
            print("\n".join(lines))
2465
2466    def _eigens_hvers_iline(self, spin, band, iline):
2467        line = self.kpoints.lines[iline]
2468        evals_on_line = self.eigens[spin, line, band]
2469        h = self.kpoints.ds[line[0]]
2470
2471        if not np.allclose(h, self.kpoints.ds[line[:-1]]):
2472            raise ValueError("For finite difference derivatives, the path must be homogeneous!\n" +
2473                             str(self.kpoints.ds[line[:-1]]))
2474
2475        return evals_on_line, h, self.kpoints.versors[line[0]]
2476
2477    def interpolate(self, lpratio=5, knames=None, vertices_names=None, line_density=20,
2478                    kmesh=None, is_shift=None, bstart=0, bstop=None, filter_params=None, verbose=0):
2479        """
2480        Interpolate energies in k-space along a k-path and, optionally, in the IBZ for DOS calculations.
2481        Note that the interpolation will likely fail if there are symmetrical k-points in the input set of k-points
2482        so it's recommended to call this method with energies obtained in the IBZ.
2483
2484        Args:
2485            lpratio: Ratio between the number of star functions and the number of ab-initio k-points.
2486                The default should be OK in many systems, larger values may be required for accurate derivatives.
2487            knames: List of strings with the k-point labels for the k-path. Has precedence over vertices_names.
2488            vertices_names: Used to specify the k-path for the interpolated band structure
2489                It's a list of tuple, each tuple is of the form (kfrac_coords, kname) where
2490                kfrac_coords are the reduced coordinates of the k-point and kname is a string with the name of
2491                the k-point. Each point represents a vertex of the k-path. ``line_density`` defines
2492                the density of the sampling. If None, the k-path is automatically generated according
2493                to the point group of the system.
2494            line_density: Number of points in the smallest segment of the k-path.
2495            kmesh: Used to activate the interpolation on the homogeneous mesh for DOS (uses spglib_ API).
2496                kmesh is given by three integers and specifies mesh numbers along reciprocal primitive axis.
2497            is_shift: three integers (spglib_ API). When is_shift is not None, the kmesh is shifted along
2498                the axis in half of adjacent mesh points irrespective of the mesh numbers. None means unshited mesh.
2499            bstart, bstop: Select the range of band to be used in the interpolation
2500            filter_params: TO BE described.
2501            verbose: Verbosity level
2502
2503        Returns:
2504                namedtuple with the following attributes::
2505
2506                    ebands_kpath: |ElectronBands| with the interpolated band structure on the k-path.
2507                    ebands_kmesh: |ElectronBands| with the interpolated band structure on the k-mesh.
2508                        None if ``kmesh`` is not given.
2509                    interpolator: |SkwInterpolator| object.
2510        """
2511        # Get symmetries from abinit spacegroup (read from file).
2512        abispg = self.structure.abi_spacegroup
2513        if abispg is None:
2514            abispg = self.structure.spgset_abi_spacegroup(has_timerev=self.has_timrev)
2515
2516        fm_symrel = [s for (s, afm) in zip(abispg.symrel, abispg.symafm) if afm == 1]
2517
2518        if self.nband > self.nelect and self.nband > 20 and bstart == 0 and bstop is None:
2519            cprint("Bands object contains nband %s with nelect %s. You may want to use bstart, bstop to select bands." % (
2520                    self.nband, self.nelect), "yellow")
2521
2522        # Build interpolator.
2523        from abipy.core.skw import SkwInterpolator
2524        cell = (self.structure.lattice.matrix, self.structure.frac_coords,
2525                self.structure.atomic_numbers)
2526
2527        skw = SkwInterpolator(lpratio, self.kpoints.frac_coords, self.eigens[:,:,bstart:bstop], self.fermie, self.nelect,
2528                              cell, fm_symrel, self.has_timrev,
2529                              filter_params=filter_params, verbose=verbose)
2530
2531        # Generate k-points for interpolation.
2532        if knames is not None:
2533            kpath = Kpath.from_names(self.structure, knames, line_density=line_density)
2534        else:
2535            if vertices_names is None:
2536                vertices_names = [(k.frac_coords, k.name) for k in self.structure.hsym_kpoints]
2537            kpath = Kpath.from_vertices_and_names(self.structure, vertices_names, line_density=line_density)
2538
2539        # Interpolate energies.
2540        eigens_kpath = skw.interp_kpts(kpath.frac_coords).eigens
2541
2542        # Build new ebands object.
2543        occfacts_kpath = np.zeros_like(eigens_kpath)
2544        ebands_kpath = self.__class__(self.structure, kpath, eigens_kpath, self.fermie, occfacts_kpath,
2545                                      self.nelect, self.nspinor, self.nspden, smearing=self.smearing)
2546        ebands_kmesh = None
2547        if kmesh is not None:
2548            # Get kpts and weights in IBZ.
2549            kdos = Ktables(self.structure, kmesh, is_shift, self.has_timrev)
2550            eigens_kmesh = skw.interp_kpts(kdos.ibz).eigens
2551
2552            # Build new ebands object with k-mesh
2553            #kptopt = kptopt_from_timrev()
2554            ksampling = KSamplingInfo.from_mpdivs(mpdivs=kmesh, shifts=[0, 0, 0], kptopt=1)
2555            kpts_kmesh = IrredZone(self.structure.reciprocal_lattice, kdos.ibz, weights=kdos.weights,
2556                                   names=None, ksampling=ksampling)
2557            occfacts_kmesh = np.zeros_like(eigens_kmesh)
2558
2559            ebands_kmesh = self.__class__(self.structure, kpts_kmesh, eigens_kmesh, self.fermie, occfacts_kmesh,
2560                                          self.nelect, self.nspinor, self.nspden, smearing=self.smearing)
2561
2562        return dict2namedtuple(ebands_kpath=ebands_kpath, ebands_kmesh=ebands_kmesh, interpolator=skw)
2563
2564    def get_collinear_mag(self):
2565        """
2566        Calculates the total collinear magnetization in Bohr magneton as the difference
2567        between the spin up and spin down densities.
2568
2569        Returns:
2570            float: the total magnetization.
2571        """
2572        if self.nsppol == 1:
2573            if self.nspinor == 1 or (self.nspinor == 2 and self.nspden == 1):
2574                return 0
2575            else:
2576                raise ValueError("Cannot calculate collinear magnetization for nsppol: {}, "
2577                                 "nspinor {}, nspden {}".format(self.nsppol, self.nspinor, self.nspden))
2578        else:
2579            rhoup = np.sum(self.kpoints.weights[:, None] * self.occfacts[0])
2580            rhoudown = np.sum(self.kpoints.weights[:, None] * self.occfacts[1])
2581            return rhoup - rhoudown
2582
2583
def dataframe_from_ebands(ebands_objects, index=None, with_spglib=True):
    """
    Build a pandas dataframe with the most important results available in a list of band structures.

    Args:
        ebands_objects: List of objects that can be converted to |ElectronBands|.
            Support netcdf filenames or |ElectronBands| objects
            See ``ElectronBands.as_ebands`` for the complete list.
        index: Index of the dataframe.
        with_spglib: If True, spglib is invoked to get the spacegroup symbol and number.

    Return: |pandas-DataFrame|
    """
    # Convert all the objects first, then extract one row per band structure.
    all_ebands = [ElectronBands.as_ebands(obj) for obj in ebands_objects]
    rows = [eb.get_dict4pandas(with_spglib=with_spglib) for eb in all_ebands]

    # Use the key order of the first row to order the columns.
    columns = list(rows[0].keys()) if rows else None
    return pd.DataFrame(rows, index=index, columns=columns)
2603
2604
2605class ElectronBandsPlotter(NotebookWriter):
2606    """
2607    Class for plotting electronic band structure and DOSes.
2608    Supports plots on the same graph or separated plots.
2609
2610    Usage example:
2611
2612    .. code-block:: python
2613
2614        plotter = ElectronBandsPlotter()
2615        plotter.add_ebands("foo-label", "foo_GSR.nc")
2616        plotter.add_ebands("bar-label", "bar_WFK.nc")
2617        fig = plotter.gridplot()
2618
2619    Dictionary with the mapping label --> edos.
2620
2621    .. rubric:: Inheritance Diagram
2622    .. inheritance-diagram:: ElectronBandsPlotter
2623    """
2624    # Used in iter_lineopt to generate matplotlib linestyles.
2625    _LINE_COLORS = ["b", "r", "g", "m", "y", "k"]
2626    _LINE_STYLES = ["-", ":", "--", "-.",]
2627    _LINE_WIDTHS = [2,]
2628
2629    def __init__(self, key_ebands=None, key_edos=None, edos_kwargs=None):
2630        """
2631        Args:
2632            key_ebands: List of (label, ebands) tuples.
2633                ebands is any object that can be converted into |ElectronBands| e.g. ncfile, path.
2634            key_edos: List of (label, edos) tuples.
2635                edos is any object that can be converted into |ElectronDos|.
2636        """
2637        if key_ebands is None: key_ebands = []
2638        key_ebands = [(k, ElectronBands.as_ebands(v)) for k, v in key_ebands]
2639        self.ebands_dict = OrderedDict(key_ebands)
2640
2641        if key_edos is None: key_edos = []
2642        key_edos = [(k, ElectronDos.as_edos(v, edos_kwargs)) for k, v in key_edos]
2643        self.edoses_dict = OrderedDict(key_edos)
2644        if key_edos:
2645            if not key_ebands:
2646                raise ValueError("key_ebands must be specifed when key_dos is not None")
2647            if len(key_ebands) != len(key_edos):
2648                raise ValueError("key_ebands and key_edos must have the same number of elements.")
2649
2650    def __repr__(self):
2651        """Invoked by repr"""
2652        return self.to_string(func=repr)
2653
2654    def __str__(self):
2655        """Invoked by str"""
2656        return self.to_string(func=str)
2657
2658    def __len__(self):
2659        return len(self.ebands_dict)
2660
2661    def add_plotter(self, other):
2662        """Merge two plotters, return new plotter."""
2663        if not isinstance(other, self.__class__):
2664            raise TypeError("Don't know to add %s to %s" % (other.__class__, self.__class__))
2665
2666        key_ebands = list(self.ebands_dict.items()) + list(other.ebands_dict.items())
2667        key_edos = list(self.edoses_dict.items()) + list(other.edoses_dict.items())
2668
2669        return self.__class__(key_ebands=key_ebands, key_edos=key_edos)
2670
2671    def to_string(self, func=str, verbose=0):
2672        """String representation."""
2673        lines = []
2674        app = lines.append
2675        for i, (label, ebands) in enumerate(self.ebands_dict.items()):
2676            app("[%d] %s --> %s" % (i, label, func(ebands)))
2677
2678        if self.edoses_dict:
2679            for i, (label, edos) in enumerate(self.edoses_dict.items()):
2680                app("[%d] %s --> %s" % (i, label, func(edos)))
2681
2682        return "\n".join(lines)
2683
2684    def get_ebands_frame(self, with_spglib=True):
2685        """
2686        Build a |pandas-DataFrame| with the most important results available in the band structures.
2687        Useful to analyze band-gaps.
2688        """
2689        return dataframe_from_ebands(list(self.ebands_dict.values()),
2690                                     index=list(self.ebands_dict.keys()), with_spglib=with_spglib)
2691
2692    @property
2693    def ebands_list(self):
2694        """"List of |ElectronBands| objects."""
2695        return list(self.ebands_dict.values())
2696
2697    @property
2698    def edoses_list(self):
2699        """"List of |ElectronDos| objects."""
2700        return list(self.edoses_dict.values())
2701
2702    def iter_lineopt(self):
2703        """Generates matplotlib linestyles."""
2704        for o in itertools.product(self._LINE_WIDTHS,  self._LINE_STYLES, self._LINE_COLORS):
2705            yield {"linewidth": o[0], "linestyle": o[1], "color": o[2]}
2706
2707    def add_ebands(self, label, bands, edos=None, edos_kwargs=None):
2708        """
2709        Adds a band structure and optionally an edos to the plotter.
2710
2711        Args:
2712            label: label for the bands. Must be unique.
2713            bands: |ElectronBands| object.
2714            edos: |ElectronDos| object.
2715            edos_kwargs: optional dictionary with the options passed to ``get_edos`` to compute the electron DOS.
2716                Used only if ``edos`` is not None and it's not an |ElectronDos| instance.
2717        """
2718        if label in self.ebands_dict:
2719            raise ValueError("label %s is already in %s" % (label, list(self.ebands_dict.keys())))
2720
2721        self.ebands_dict[label] = ElectronBands.as_ebands(bands)
2722        if edos is not None:
2723            self.edoses_dict[label] = ElectronDos.as_edos(edos, edos_kwargs)
2724
2725    def bands_statdiff(self, ref=0):
2726        """
2727        Compare the reference bands with index ref with the other bands stored in the plotter.
2728        """
2729        for i, label in enumerate(self.ebands_dict.keys()):
2730            if i == ref:
2731                ref_label = label
2732                break
2733        else:
2734            raise ValueError("ref index %s is > number of bands" % ref)
2735
2736        ref_bands = self.ebands_dict[ref_label]
2737
2738        text = []
2739        for label, bands in self.ebands_dict.items():
2740            if label == ref_label: continue
2741            stat = ref_bands.statdiff(bands)
2742            text.append(str(stat))
2743
2744        return "\n\n".join(text)
2745
2746    def yield_figs(self, **kwargs):  # pragma: no cover
2747        """
2748        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
2749        """
2750        for mname in ("gridplot", "boxplot"):
2751            yield getattr(self, mname)(show=False)
2752
2753    @add_fig_kwargs
2754    def combiplot(self, e0="fermie", ylims=None, width_ratios=(2, 1), fontsize=8,
2755                  linestyle_dict=None, **kwargs):
2756        """
2757        Plot the band structure and the DOS on the same figure.
2758        Use ``gridplot`` to plot band structures on different figures.
2759
2760        Args:
2761            e0: Option used to define the zero of energy in the band structure plot. Possible values::
2762
2763                - `fermie`: shift all eigenvalues to have zero energy at the Fermi energy (ebands.fermie)
2764                   Note that, by default, the Fermi energy is taken from the band structure object
2765                   i.e. the Fermi energy computed at the end of the SCF file that produced the density.
2766                   This should be ok in semiconductors. In metals, however, a better value of the Fermi energy
2767                   can be obtained from the DOS provided that the k-sampling for the DOS is much denser than
2768                   the one used to compute the density. See `edos_fermie`.
2769                - ``edos_fermie``: Use the Fermi energy computed from the DOS to define the zero of energy in both subplots.
2770                   Available only if plotter contains dos objects.
2771                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
2772                -  None: Don't shift energies, equivalent to e0=0
2773
2774            ylims: Set the data limits for the y-axis. Accept tuple e.g. `(left, right)`
2775                   or scalar e.g. `left`. If left (right) is None, default values are used
2776            width_ratios: Defines the ratio between the band structure plot and the dos plot.
2777                Used when there are DOS stored in the plotter.
2778            fontsize: fontsize for titles and legend.
2779            linestyle_dict: Dictionary mapping labels to matplotlib linestyle options.
2780
2781        Returns: |matplotlib-Figure|.
2782        """
2783        import matplotlib.pyplot as plt
2784        from matplotlib.gridspec import GridSpec
2785        fig = plt.figure()
2786
2787        if self.edoses_dict:
2788            # Build grid with two axes.
2789            gspec = GridSpec(nrows=1, ncols=2, width_ratios=width_ratios, wspace=0.05)
2790            # bands and DOS will share the y-axis
2791            ax0 = plt.subplot(gspec[0])
2792            ax1 = plt.subplot(gspec[1], sharey=ax0)
2793            ax_list = [ax0, ax1]
2794        else:
2795            # One axis for bands only
2796            ax0 = fig.add_subplot(111)
2797            ax_list = [ax0]
2798
2799        for ax in ax_list:
2800            ax.grid(True)
2801            set_axlims(ax, ylims, "y")
2802
2803        # Plot ebands.
2804        lines, legends = [], []
2805        my_kwargs, opts_label = kwargs.copy(), {}
2806        i = -1
2807        nkpt_list = [ebands.nkpt for ebands in self.ebands_dict.values()]
2808        if any(nk != nkpt_list[0] for nk in nkpt_list):
2809            cprint("WARNING: Bands have different number of k-points:\n%s" % str(nkpt_list), "yellow")
2810
2811        for (label, ebands), lineopt in zip(self.ebands_dict.items(), self.iter_lineopt()):
2812            i += 1
2813            if linestyle_dict is not None and label in linestyle_dict:
2814                my_kwargs.update(linestyle_dict[label])
2815            else:
2816                my_kwargs.update(lineopt)
2817
2818            opts_label[label] = my_kwargs.copy()
2819
2820            # Get energy zero.
2821            mye0 = self.edoses_dict[label].fermie if e0 == "edos_fermie" else ebands.get_e0(e0)
2822
2823            l = ebands.plot_ax(ax0, mye0, spin=None, band=None, **my_kwargs)
2824            lines.append(l[0])
2825
2826            # Use relative paths if label is a file.
2827            if os.path.isfile(label):
2828                legends.append("%s" % os.path.relpath(label))
2829            else:
2830                legends.append("%s" % label)
2831
2832            # Set ticks and labels, legends.
2833            if i == 0:
2834                ebands.decorate_ax(ax0)
2835
2836        ax0.legend(lines, legends, loc='upper right', fontsize=fontsize, shadow=True)
2837
2838        # Add DOSes
2839        if self.edoses_dict:
2840            ax = ax_list[1]
2841            for label, edos in self.edoses_dict.items():
2842                ebands = self.edoses_dict[label]
2843                mye0 = ebands.get_e0(e0) if e0 != "edos_fermie" else edos.fermie
2844                edos.plot_ax(ax, mye0, exchange_xy=True, **opts_label[label])
2845
2846        return fig
2847
2848    def plot(self, *args, **kwargs):
2849        """An alias for combiplot."""
2850        if "align" in kwargs or "xlim" in kwargs or "ylim" in kwargs:
2851            raise ValueError("align|xlim|ylim options are not supported anymore.")
2852        return self.combiplot(*args, **kwargs)
2853
2854    @add_fig_kwargs
2855    def gridplot(self, e0="fermie", with_dos=True, with_gaps=False, max_phfreq=None,
2856                 ylims=None, fontsize=8, **kwargs):
2857        """
2858        Plot multiple electron bandstructures and optionally DOSes on a grid.
2859
2860        Args:
2861            eb_objects: List of objects from which the band structures are extracted.
2862                Each item in eb_objects is either a string with the path of the netcdf file,
2863                or one of the abipy object with an ``ebands`` attribute or a |ElectronBands| object.
2864            edos_objects: List of objects from which the electron DOSes are extracted.
2865                Accept filepaths or |ElectronDos| objects. If edos_objects is not None,
2866                each subplot in the grid contains a band structure with DOS else a simple bandstructure plot.
2867            e0: Option used to define the zero of energy in the band structure plot. Possible values::
2868
2869                - ``fermie``: shift all eigenvalues and the DOS to have zero energy at the Fermi energy.
2870                   Note that, by default, the Fermi energy is taken from the band structure object
2871                   i.e. the Fermi energy computed at the end of the SCF file that produced the density.
2872                   This should be ok in semiconductors. In metals, however, a better value of the Fermi energy
2873                   can be obtained from the DOS provided that the k-sampling for the DOS is much denser than
2874                   the one used to compute the density. See `edos_fermie`.
2875                - ``edos_fermie``: Use the Fermi energy computed from the DOS to define the zero of energy in both subplots.
2876                   Available only if edos_objects is not None
2877                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
2878                -  None: Don't shift energies, equivalent to e0=0
2879
2880            with_dos: True if DOS should be printed.
2881            with_gaps: True to add markesr and arrows showing the fundamental and the direct gap.
2882            max_phfreq: Max phonon frequency in eV to activate scatterplot showing
2883                possible phonon absorptions/emission processes based on energy-conservation alone.
2884                All final states whose energy is within +- max_phfreq of the initial state are included.
2885                By default, the four electronic states defining the fundamental and the direct gaps
2886                are considered as initial state (not available for metals).
2887            ylims: Set the data limits for the y-axis. Accept tuple e.g. ```(left, right)``
2888                   or scalar e.g. ``left``. If left (right) is None, default values are used
2889            fontsize: fontsize for titles and legend.
2890
2891        Returns: |matplotlib-Figure|
2892        """
2893        titles = list(self.ebands_dict.keys())
2894        ebands_list, edos_list = self.ebands_list, self.edoses_list
2895
2896        import matplotlib.pyplot as plt
2897        nrows, ncols = 1, 1
2898        numeb = len(ebands_list)
2899        if numeb > 1:
2900            ncols = 2
2901            nrows = numeb // ncols + numeb % ncols
2902
2903        if not edos_list or not with_dos:
2904            # Plot grid with bands only.
2905            fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, sharey=True, squeeze=False)
2906            ax_list = ax_list.ravel()
2907            # don't show the last ax if numeb is odd.
2908            if numeb % ncols != 0: ax_list[-1].axis("off")
2909
2910            for i, (ebands, ax) in enumerate(zip(ebands_list, ax_list)):
2911                irow, icol = divmod(i, ncols)
2912                ebands.plot(ax=ax, e0=e0, with_gaps=with_gaps, max_phfreq=max_phfreq, show=False)
2913                set_axlims(ax, ylims, "y")
2914                # This to handle with_gaps = True
2915                title = ax.get_title()
2916                if not title: ax.set_title(titles[i], fontsize=fontsize)
2917                if (irow, icol) != (0, 0):
2918                    set_visible(ax, False, "ylabel")
2919
2920        else:
2921            # Plot grid with bands + DOS. see http://matplotlib.org/users/gridspec.html
2922            from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
2923            fig = plt.figure()
2924            gspec = GridSpec(nrows, ncols)
2925
2926            for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list)):
2927                subgrid = GridSpecFromSubplotSpec(1, 2, subplot_spec=gspec[i], width_ratios=[2, 1], wspace=0.05)
2928                # Get axes and align bands and DOS.
2929                ax0 = plt.subplot(subgrid[0])
2930                ax1 = plt.subplot(subgrid[1], sharey=ax0)
2931                set_axlims(ax0, ylims, "y")
2932                set_axlims(ax1, ylims, "y")
2933
2934                # Define the zero of energy and plot
2935                mye0 = ebands.get_e0(e0) if e0 != "edos_fermie" else edos.fermie
2936                ebands.plot_with_edos(edos, e0=mye0, ax_list=(ax0, ax1), with_gaps=with_gaps,
2937                                      max_phfreq=max_phfreq, show=False)
2938
2939                # This to handle with_gaps = True
2940                title = ax0.get_title()
2941                if not title: ax0.set_title(titles[i], fontsize=fontsize)
2942                if i % ncols != 0:
2943                    for ax in (ax0, ax1):
2944                        ax.set_ylabel("")
2945
2946        return fig
2947
2948    @add_fig_kwargs
2949    def boxplot(self, e0="fermie", brange=None, swarm=False, fontsize=8, **kwargs):
2950        """
2951        Use seaborn_ to draw a box plot to show distributions of eigenvalues with respect to the band index.
2952        Band structures are drawn on different subplots.
2953
2954        Args:
2955            e0: Option used to define the zero of energy in the band structure plot. Possible values:
2956                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
2957                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
2958                -  None: Don't shift energies, equivalent to e0=0
2959            brange: Only bands such as ``brange[0] <= band_index < brange[1]`` are included in the plot.
2960            swarm: True to show the datapoints on top of the boxes
2961            fontsize: Fontsize for title.
2962            kwargs: Keyword arguments passed to seaborn boxplot.
2963        """
2964        # Build grid of plots.
2965        num_plots, ncols, nrows = len(self.ebands_dict), 1, 1
2966        if num_plots > 1:
2967            ncols = 2
2968            nrows = (num_plots // ncols) + (num_plots % ncols)
2969
2970        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
2971                                                sharex=False, sharey=True, squeeze=False)
2972        ax_list = ax_list.ravel()
2973        # don't show the last ax if numeb is odd.
2974        if num_plots % ncols != 0: ax_list[-1].axis("off")
2975
2976        for (label, ebands), ax in zip(self.ebands_dict.items(), ax_list):
2977            ebands.boxplot(ax=ax, brange=brange, show=False)
2978            ax.set_title(label, fontsize=fontsize)
2979
2980        return fig
2981
2982    @add_fig_kwargs
2983    def combiboxplot(self, e0="fermie", brange=None, swarm=False, ax=None, **kwargs):
2984        """
2985        Use seaborn_ to draw a box plot comparing the distributions of the eigenvalues
2986        Band structures are drawn on the same plot.
2987
2988        Args:
2989            e0: Option used to define the zero of energy in the band structure plot. Possible values:
2990                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
2991                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
2992                -  None: Don't shift energies, equivalent to e0=0
2993
2994            brange: Only bands such as ``brange[0] <= band_index < brange[1]`` are included in the plot.
2995            swarm: True to show the datapoints on top of the boxes
2996            ax: |matplotlib-Axes| or None if a new figure should be created.
2997            kwargs: Keyword arguments passed to seaborn boxplot.
2998        """
2999        spin_polarized = False
3000        frames = []
3001        for label, ebands in self.ebands_dict.items():
3002            # Get the dataframe, select bands and add column with label
3003            frame = ebands.get_dataframe(e0=e0)
3004            if brange is not None:
3005                frame = frame[(frame["band"] >= brange[0]) & (frame["band"] < brange[1])]
3006            frame["label"] = label
3007            frames.append(frame)
3008            if ebands.nsppol == 2: spin_polarized = True
3009
3010        # Merge frames ignoring index (not meaningful)
3011        data = pd.concat(frames, ignore_index=True)
3012
3013        import seaborn as sns
3014        if not spin_polarized:
3015            ax, fig, plt = get_ax_fig_plt(ax=ax)
3016            ax.grid(True)
3017            sns.boxplot(x="band", y="eig", data=data, hue="label", ax=ax, **kwargs)
3018            if swarm:
3019                sns.swarmplot(x="band", y="eig", data=data, hue="label", color=".25", ax=ax)
3020        else:
3021            # Generate two subplots for spin-up / spin-down channels.
3022            import matplotlib.pyplot as plt
3023            if ax is not None:
3024                raise NotImplementedError("ax == None not implemented when nsppol==2")
3025            fig, ax_list = plt.subplots(nrows=2, ncols=1, sharex=True, squeeze=False)
3026            for spin, ax in zip(range(2), ax_list.ravel()):
3027                ax.grid(True)
3028                data_spin = data[data["spin"] == spin]
3029                sns.boxplot(x="band", y="eig", data=data_spin, hue="label", ax=ax, **kwargs)
3030                if swarm:
3031                    sns.swarmplot(x="band", y="eig", data=data_spin, hue="label", color=".25", ax=ax)
3032
3033        return fig
3034
3035    @add_fig_kwargs
3036    def plot_band_edges(self, e0="fermie", epad_ev=1.0, set_fermie_to_vbm=True, colormap="viridis", fontsize=8, **kwargs):
3037        """
3038        Plot the band edges for electrons and holes on two separated plots for all ebands in ebands_dict.
3039        Useful for comparing band structures obtained with/without SOC or bands obtained with different settings.
3040
3041        Args:
3042            e0: Option used to define the zero of energy in the band structure plot. Possible values:
3043                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
3044                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
3045                -  None: Don't shift energies, equivalent to e0=0
3046            epad_ev: Add this energy window in eV above VBM and below CBM.
3047            set_fermie_to_vbm: True if Fermi energy should be recomputed and fixed at max occupied energy level.
3048            colormap: matplotlib colormap.
3049            fontsize: legend and title fontsize.
3050        """
3051        # Two subplots for CBM and VBM
3052        num_plots, ncols, nrows = 2, 1, 2
3053        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
3054                                                sharex=False, sharey=False, squeeze=False)
3055        ax_list = ax_list.ravel()
3056        cmap = plt.get_cmap(colormap)
3057        nb = len(self.ebands_dict.items())
3058
3059        for ix, ax in enumerate(ax_list):
3060            for iband, (label, ebands) in enumerate(self.ebands_dict.items()):
3061                if set_fermie_to_vbm:
3062                    # This is needed when the fermi energy is computed in the GS part
3063                    # with a mesh that does not contain the band edges.
3064                    ebands.set_fermie_to_vbm()
3065
3066                if ix == 0:
3067                    # Conduction
3068                    ymin = min((ebands.lumos[spin].eig for spin in ebands.spins)) - 0.1
3069                    ymax = ymin + epad_ev
3070                elif ix == 1:
3071                    # Valence
3072                    ymax = max((ebands.homos[spin].eig for spin in ebands.spins)) + 0.1
3073                    ymin = ymax - epad_ev
3074                else:
3075                    raise ValueError("Wrong ix: %s" % ix)
3076
3077                # Defin ylims and energy shift.
3078                this_e0 = ebands.get_e0(e0)
3079                ylims = (ymin - this_e0, ymax - this_e0)
3080                ebands.plot(ax=ax, e0=e0, color=cmap(float(iband) / nb), ylims=ylims,
3081                            label=label if ix == 0 else None, show=False)
3082            if ix == 0:
3083                ax.legend(loc="best", fontsize=fontsize, shadow=True)
3084
3085        return fig
3086
3087    def animate(self, e0="fermie", interval=500, savefile=None, width_ratios=(2, 1), show=True):
3088        """
3089        Use matplotlib_ to animate a list of band structure plots (with or without DOS).
3090
3091        Args:
3092            e0: Option used to define the zero of energy in the band structure plot. Possible values::
3093
3094                * ``fermie``: shift all eigenvalues and the DOS to have zero energy at the Fermi energy.
3095                   Note that, by default, the Fermi energy is taken from the band structure object
3096                   i.e. the Fermi energy computed at the end of the SCF file that produced the density.
3097                   See `edos_fermie`.
3098                * ``edos_fermie``: Use the Fermi energy computed from the DOS to define the zero of energy in both subplots.
3099                *  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
3100                *  None: Don't shift energies, equivalent to e0=0
3101
3102            interval: draws a new frame every interval milliseconds.
3103            savefile: Use e.g. 'myanimation.mp4' to save the animation in mp4 format.
3104            width_ratios: Defines the ratio between the band structure plot and the dos plot.
3105                Used when there are DOS stored in the plotter.
3106            show: True if the animation should be shown immediately
3107
3108        Returns: Animation object.
3109
3110        .. See also::
3111
3112            http://matplotlib.org/api/animation_api.html
3113            http://jakevdp.github.io/blog/2012/08/18/matplotlib-animation-tutorial/
3114
3115        .. Note::
3116
3117            It would be nice to animate the title of the plot, unfortunately
3118            this feature is not available in the present version of matplotlib.
3119            See: http://stackoverflow.com/questions/17558096/animated-title-in-matplotlib
3120        """
3121        ebands_list, edos_list = self.ebands_list, self.edoses_list
3122        if edos_list and len(edos_list) != len(ebands_list):
3123            raise ValueError("The number of objects for DOS must be equal to the number of bands")
3124        #titles = list(self.ebands_dict.keys())
3125
3126        import matplotlib.pyplot as plt
3127        fig = plt.figure()
3128        plotax_kwargs = {"color": "black", "linewidth": 2.0}
3129
3130        artists = []
3131        if not edos_list:
3132            # Animation with band structures
3133            ax = fig.add_subplot(1, 1, 1)
3134            ebands_list[0].decorate_ax(ax)
3135            for i, ebands in enumerate(ebands_list):
3136                lines = ebands.plot_ax(ax, e0, **plotax_kwargs)
3137                #if titles is not None: lines += [ax.set_title(titles[i])]
3138                artists.append(lines)
3139        else:
3140            # Animation with band structures + DOS.
3141            from matplotlib.gridspec import GridSpec
3142            gspec = GridSpec(nrows=1, ncols=2, width_ratios=width_ratios, wspace=0.05)
3143            ax0 = plt.subplot(gspec[0])
3144            ax1 = plt.subplot(gspec[1], sharey=ax0)
3145            ebands_list[0].decorate_ax(ax0)
3146            ax1.grid(True)
3147            ax1.yaxis.set_ticks_position("right")
3148            ax1.yaxis.set_label_position("right")
3149
3150            for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list)):
3151                # Define the zero of energy to align bands and dos
3152                mye0 = ebands.get_e0(e0) if e0 != "edos_fermie" else edos.fermie
3153                ebands_lines = ebands.plot_ax(ax0, mye0, **plotax_kwargs)
3154                edos_lines = edos.plot_ax(ax1, mye0, exchange_xy=True, **plotax_kwargs)
3155                lines = ebands_lines + edos_lines
3156                #if titles is not None: lines += [ax.set_title(titles[i])]
3157                artists.append(lines)
3158
3159        import matplotlib.animation as animation
3160        anim = animation.ArtistAnimation(fig, artists, interval=interval,
3161                                         blit=False, # True is faster but then the movie starts with an empty frame!
3162                                         #repeat_delay=1000
3163                                         )
3164
3165        if savefile is not None: anim.save(savefile)
3166        if show: plt.show()
3167
3168        return anim
3169
3170    def _repr_html_(self):
3171        """Integration with jupyter_ notebooks."""
3172        return self.ipw_select_plot()
3173
3174    def ipw_select_plot(self): # pragma: no cover
3175        """
3176        Return an ipython widget with controllers to select the plot.
3177        """
3178        def plot_callback(plot_type, e0):
3179            r = getattr(self, plot_type)(e0=e0, show=True)
3180            if plot_type == "animate": return r
3181
3182        import ipywidgets as ipw
3183        return ipw.interact_manual(
3184                plot_callback,
3185                plot_type=["combiplot", "gridplot", "boxplot", "combiboxplot", "animate"],
3186                e0=["fermie", "0.0"],
3187            )
3188
3189    def write_notebook(self, nbpath=None):
3190        """
3191        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporay file in the current
3192        working directory is created. Return path to the notebook.
3193        """
3194        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)
3195
3196        # Use pickle file for data persistence.
3197        tmpfile = self.pickle_dump()
3198
3199        nb.cells.extend([
3200            #nbv.new_markdown_cell("# This is a markdown cell"),
3201            nbv.new_code_cell("plotter = abilab.ElectronBandsPlotter.pickle_load('%s')" % tmpfile),
3202            nbv.new_code_cell("print(plotter)"),
3203            nbv.new_code_cell("frame = plotter.get_ebands_frame()\ndisplay(frame)"),
3204            nbv.new_code_cell("ylims = (None, None)"),
3205            nbv.new_code_cell("plotter.gridplot(ylims=ylims);"),
3206            nbv.new_code_cell("plotter.combiplot(ylims=ylims);"),
3207            nbv.new_code_cell("plotter.boxplot();"),
3208            nbv.new_code_cell("plotter.combiboxplot();"),
3209            nbv.new_code_cell("if False: anim = plotter.animate()"),
3210        ])
3211
3212        return self._write_nb_nbpath(nb, nbpath)
3213
3214    #def _can_use_basenames_as_labels(self):
3215    #    """
3216    #    Return True if all labels represent valid files and the basenames are unique
3217    #    In this case one can use the file basename instead of the full path in the plots.
3218    #    """
3219    #    if not all(os.path.exists(l) for l in self.ebands_dict): return False
3220    #    labels = [os.path.basename(l) for l in self.ebands_dict]
3221    #    return len(set(labels)) == len(labels)
3222
3223
class ElectronsReader(ETSF_Reader, KpointsReaderMixin):
    """
    This object reads band structure data from a netcdf_ file.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: ElectronsReader
    """
    def read_ebands(self):
        """
        Returns an instance of |ElectronBands|. Main entry point for client code
        """
        ebands = ElectronBands(
            structure=self.read_structure(),
            kpoints=self.read_kpoints(),
            eigens=self.read_eigenvalues(),
            fermie=self.read_fermie(),
            occfacts=self.read_occupations(),
            nelect=self.read_nelect(),
            nspinor=self.read_nspinor(),
            nspden=self.read_nspden(),
            nband_sk=self.read_nband_sk(),
            smearing=self.read_smearing(),
            )

        # This is to solve the typical problem in semiconductors that shows up
        # when the Fermi level from the GS run computed with a shifted k-mesh
        # underestimates the CBM at Gamma.
        # Applied only for occopt == 1 (no smearing).
        if ebands.smearing.occopt == 1:
            ebands.set_fermie_to_vbm()

        return ebands

    def read_nband_sk(self):
        """|numpy-array| with the number of bands indexed by [s, k]."""
        return self.read_value("number_of_states")

    def read_nspinor(self):
        """Number of spinors."""
        return self.read_dimvalue("number_of_spinor_components")

    def read_nsppol(self):
        """Number of independent spins (collinear case)."""
        return self.read_dimvalue("number_of_spins")

    def read_nspden(self):
        """Number of spin-density components"""
        # FIXME: default 1 is needed for SIGRES files (abinit8)
        return self.read_dimvalue("number_of_components", default=1)

    def read_tsmear(self):
        """Smearing width as stored in the file, in Ha (no unit conversion is applied)."""
        return self.read_value("smearing_width")

    def read_eigenvalues(self):
        """Eigenvalues in eV."""
        return units.ArrayWithUnit(self.read_value("eigenvalues"), "Ha").to("eV")

    def read_occupations(self):
        """Occupancies."""
        return self.read_value("occupations")

    def read_fermie(self):
        """Fermi level in eV."""
        return units.Energy(self.read_value("fermi_energy"), "Ha").to("eV")

    def read_nelect(self):
        """Number of valence electrons."""
        return self.read_value("number_of_electrons")

    def read_smearing(self):
        """Returns a :class:`Smearing` instance with info on the smearing technique."""
        occopt = int(self.read_value("occopt"))
        scheme = self.read_string("smearing_scheme")

        return Smearing(
            scheme=scheme,
            occopt=occopt,
            # read_tsmear returns Ha, Smearing expects eV.
            tsmear_ev=units.Energy(self.read_tsmear(), "Ha").to("eV")
        )
3303
3304
3305class ElectronDos(object):
3306    """
3307    This object stores the electronic density of states.
3308    It is usually created by calling the get_edos method of |ElectronBands|.
3309    """
3310
3311    def __init__(self, mesh, spin_dos, nelect, fermie=None, spin_idos=None):
3312        """
3313        Args:
3314            mesh: array-like object with the mesh points in eV.
3315            spin_dos: array-like object with the DOS for the different spins (even if spin-unpolarized calculation).
3316                Shape is:
3317                      (1, nw) if spin-unpolarized.
3318                      (2, nw) if spin-polarized.
3319            nelect: Number of electrons in the unit cell.
3320            fermie: Fermi level in eV. If None, fermie is obtained from the idos integral.
3321            spin_idos: array-like object with the IDOS for the different spins (even if spin-unpolarized calculation).
3322                Shape is:
3323                      (1, nw) if spin-unpolarized.
3324                      (2, nw) if spin-polarized case.
3325
3326                This argument is usually used when we have an IDOS computed with a more accurate method e.g.
3327                tetrahedron integration so that we can use these values instead of integrating the input DOS.
3328
3329        .. note::
3330
3331            mesh is given in eV, spin_dos is in states/eV.
3332        """
3333        spin_dos = np.atleast_2d(spin_dos)
3334        self.nsppol = len(spin_dos)
3335        self.nelect = nelect
3336        if spin_idos is not None:
3337            spin_idos = np.atleast_2d(spin_idos)
3338            assert len(spin_idos) == self.nsppol
3339
3340        # Save DOS and IDOS for each spin.
3341        sumv = np.zeros(len(mesh))
3342        self.spin_dos, self.spin_idos = [], []
3343        for ispin, values in enumerate(spin_dos):
3344            sumv += values
3345            f = Function1D(mesh, values)
3346            self.spin_dos.append(f)
3347            # Compute IDOS or take it from spin_idos.
3348            if spin_idos is None:
3349                self.spin_idos.append(f.integral())
3350            else:
3351                self.spin_idos.append(Function1D(mesh, spin_idos[ispin]))
3352
3353        # Total DOS and IDOS.
3354        if self.nsppol == 1: sumv = 2 * sumv
3355        self.tot_dos = Function1D(mesh, sumv)
3356        if spin_idos is None:
3357            # Compute IDOS from DOS
3358            self.tot_idos = self.tot_dos.integral()
3359        else:
3360            # Get IDOS from input (e.g. tetra)
3361            if self.nsppol == 1: sumv = 2 * spin_idos[0]
3362            if self.nsppol == 2: sumv = spin_idos[0] + spin_idos[1]
3363            self.tot_idos = Function1D(mesh, sumv)
3364
3365        if fermie is not None:
3366            self.fermie = float(fermie)
3367        else:
3368            # *Compute* fermie from nelect. Note that this value may differ
3369            # from the one stored in ElectronBands (coming from the SCF run)
3370            # The accuracy of self.fermie depends on the number of k-points used for the DOS
3371            # and the parameters used to call ebands.get_edos.
3372            try:
3373                self.fermie = self.find_mu(self.nelect)
3374            except ValueError:
3375                print("tot_idos values:\n", self.tot_idos)
3376                raise
3377
3378    def __str__(self):
3379        return self.to_string()
3380
3381    def to_string(self, verbose=0):
3382        """String representation."""
3383        lines = []; app = lines.append
3384        app("nsppol: %d, nelect: %s" % (self.nsppol, self.nelect))
3385        app("Fermi energy: %s (eV) (recomputed from nelect):" % self.fermie)
3386        return "\n".join(lines)
3387
3388    @classmethod
3389    def as_edos(cls, obj, edos_kwargs):
3390        """
3391        Return an instance of |ElectronDos| from a generic object ``obj``.
3392        Supports:
3393
3394            - instances of cls
3395            - files (string) that can be open with abiopen and that provide an `ebands` attribute.
3396            - objects providing an `ebands` or `get_edos` attribute
3397
3398        Args:
3399            edos_kwargs: optional dictionary with the options passed to `get_edos` to compute the electron DOS.
3400            Used when obj is not already an instance of ``cls``.
3401        """
3402        if edos_kwargs is None: edos_kwargs = {}
3403        if isinstance(obj, cls):
3404            return obj
3405        elif is_string(obj):
3406            # path?
3407            if obj.endswith(".pickle"):
3408                with open(obj, "rb") as fh:
3409                    return cls.as_edos(pickle.load(fh), edos_kwargs)
3410
3411            from abipy.abilab import abiopen
3412            with abiopen(obj) as abifile:
3413                if hasattr(abifile, "ebands"):
3414                    return abifile.ebands.get_edos(**edos_kwargs)
3415                elif hasattr(abifile, "edos"):
3416                    # This to handle e.g. the _EDOS file.
3417                    return abifile.edos
3418                else:
3419                    raise TypeError("Don't know how to extract ElectronDos object from: `%s`" % str(obj))
3420
3421        elif hasattr(obj, "ebands"):
3422            return obj.ebands.get_edos(**edos_kwargs)
3423
3424        elif hasattr(obj, "get_edos"):
3425            return obj.get_edos(**edos_kwargs)
3426
3427        raise TypeError("Don't know how to create `ElectronDos` from %s" % type(obj))
3428
3429    def __eq__(self, other):
3430        if other is None: return False
3431        if self.nsppol != other.nsppol: return False
3432        for f1, f2 in zip(self.spin_dos, other.spin_dos):
3433            if f1 != f2: return False
3434        return True
3435
3436    def __ne__(self, other):
3437        return not (self == other)
3438
3439    def dos_idos(self, spin=None):
3440        """
3441        Returns DOS and IDOS for given spin. Total DOS and IDOS if spin is None.
3442        """
3443        if spin is None:
3444            return self.tot_dos, self.tot_idos
3445        else:
3446            return self.spin_dos[spin], self.spin_idos[spin]
3447
3448    def find_mu(self, nelect, spin=None):
3449        """
3450        Finds the chemical potential given the number of electrons.
3451        """
3452        idos = self.tot_idos if spin is None else self.spin_idos[spin]
3453
3454        # Cannot use bisection because DOS might be negative due to smearing.
3455        # This one is safer albeit slower.
3456        for i, (ene, intg) in enumerate(idos):
3457            if intg > nelect: break
3458        else:
3459            # If the mesh is not large enough, we never cross nelect
3460            # If the last point in IDOS is sufficiently close to nelect
3461            # use it as Fermi level.
3462            if abs(idos.values[-1] - nelect) < 1e-3:
3463                i = len(idos) - 1
3464            else:
3465                raise ValueError("Cannot find I(e) such that I(e) > nelect")
3466
3467        # Use linear interpolation to find mu (useful if mesh is coarse)
3468        e0, y0 = idos[i-1]
3469        e1, y1 = idos[i]
3470
3471        alpha = (y1 - y0) / (e1 - e0)
3472        beta = y0 - alpha * e0
3473        mu = (nelect - beta) / alpha
3474        #print("idos[i-1]:", idos[i-1], "idos[i]:", idos[i], "intg", intg, "nelect", nelect)
3475        #print("mu linear", mu)
3476        return mu
3477
3478    @lazy_property
3479    def up_minus_down(self):
3480        """
3481        Function1D with dos_up - dos_down
3482        """
3483        if self.nsppol == 1: # DOH!
3484            return Function1D.from_constant(self.spin_dos[0].mesh, 0.0)
3485        else:
3486            return self.spin_dos[0] - self.spin_dos[1]
3487
3488    def get_e0(self, e0):
3489        """
3490        e0: Option used to define the zero of energy in the band structure plot. Possible values:
3491                - `fermie`: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
3492                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
3493                -  None: Don't shift energies, equivalent to e0=0
3494        """
3495        if e0 is None:
3496            return 0.0
3497
3498        elif is_string(e0):
3499            if e0 == "fermie":
3500                return self.fermie
3501            elif e0 == "None":
3502                return 0.0
3503            else:
3504                try:
3505                    return float(e0)
3506                except Exception:
3507                    raise TypeError("Wrong value for e0: %s" % str(e0))
3508        else:
3509            # Assume number
3510            return float(e0)
3511
3512    def plot_ax(self, ax, e0, spin=None, what="dos", fact=1.0, exchange_xy=False, **kwargs):
3513        """
3514        Helper function to plot the DOS data on the axis ``ax``.
3515
3516        Args:
3517            ax: |matplotlib-Axes|.
3518            e0: Option used to define the zero of energy in the band structure plot.
3519            spin: selects the spin component, None for total DOS, IDOS.
3520            what: string selecting what will be plotted. "dos" for DOS, "idos" for IDOS
3521            fact: Multiplication factor for DOS/IDOS. Usually +-1 for spin DOS
3522            exchange_xy: True to exchange x-y axis.
3523            kwargs: Options passed to matplotlib ``ax.plot``
3524
3525        Return: list of lines added to the axis ax.
3526        """
3527        dosf, idosf = self.dos_idos(spin=spin)
3528        e0 = self.get_e0(e0)
3529
3530        w2f = {"dos": dosf, "idos": idosf}
3531        if what not in w2f:
3532            raise ValueError("Unknown value for what: `%s`" % str(what))
3533        f = w2f[what]
3534
3535        xx, yy = f.mesh - e0, f.values * fact
3536        if exchange_xy: xx, yy = yy, xx
3537        lines = []
3538        lines.extend(ax.plot(xx, yy, **kwargs))
3539
3540        return lines
3541
3542    @add_fig_kwargs
3543    def plot(self, e0="fermie", spin=None, ax=None, exchange_xy=False, xlims=None, ylims=None, **kwargs):
3544        """
3545        Plot electronic DOS
3546
3547        Args:
3548            e0: Option used to define the zero of energy in the band structure plot. Possible values:
3549                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
3550                - Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
3551                - None: Don't shift energies, equivalent to ``e0 = 0``.
3552            spin: Selects the spin component, None if total DOS is wanted.
3553            ax: |matplotlib-Axes| or None if a new figure should be created.
3554            exchange_xy: True to exchange x-y axis.
3555            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
3556                or scalar e.g. ``left``. If left (right) is None, default values are used
3557            ylims: Set data limits for the y-axis.
3558            kwargs: options passed to ``ax.plot``.
3559
3560        Return: |matplotlib-Figure|
3561        """
3562        ax, fig, plt = get_ax_fig_plt(ax=ax)
3563        e0 = self.get_e0(e0)
3564
3565        for spin in range(self.nsppol):
3566            opts = {"color": "black", "linewidth": 1.0} if spin == 0 else \
3567                   {"color": "red", "linewidth": 1.0}
3568            opts.update(kwargs)
3569            spin_sign = +1 if spin == 0 else -1
3570            x, y = self.spin_dos[spin].mesh - e0, spin_sign * self.spin_dos[spin].values
3571            if exchange_xy: x, y = y, x
3572            ax.plot(x, y, **opts)
3573
3574        ax.grid(True)
3575        xlabel, ylabel = 'Energy (eV)', 'DOS (states/eV)'
3576        set_ax_xylabels(ax, xlabel, ylabel, exchange_xy)
3577        set_axlims(ax, xlims, "x")
3578        set_axlims(ax, ylims, "y")
3579
3580        return fig
3581
3582    @add_fig_kwargs
3583    def plot_dos_idos(self, e0="fermie", ax_list=None, xlims=None, height_ratios=(1, 2), **kwargs):
3584        """
3585        Plot electronic DOS and Integrated DOS on two different subplots.
3586
3587        Args:
3588            e0: Option used to define the zero of energy in the band structure plot. Possible values:
3589                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
3590                -  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
3591                -  None: Don't shift energies, equivalent to ``e0 = 0``.
3592            ax_list: The axes for the DOS and IDOS plot. If ax_list is None, a new figure
3593                is created and the two axes are automatically generated.
3594            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
3595                   or scalar e.g. ``left``. If left (right) is None, default values are used
3596            height_ratios:
3597            kwargs: options passed to ``plot_ax``
3598
3599        Return: |matplotlib-Figure|
3600        """
3601        import matplotlib.pyplot as plt
3602        from matplotlib.gridspec import GridSpec
3603
3604        if ax_list is None:
3605            fig = plt.figure()
3606            gspec = GridSpec(nrows=2, ncols=1, height_ratios=height_ratios, wspace=0.05)
3607            ax0 = plt.subplot(gspec[0])
3608            ax1 = plt.subplot(gspec[1], sharex=ax0)
3609            ax_list = [ax0, ax1]
3610
3611            for ax in ax_list:
3612                ax.grid(True)
3613                set_axlims(ax, xlims, "x")
3614
3615            ax0.set_ylabel("TOT IDOS")
3616            ax1.set_ylabel("TOT DOS")
3617            ax1.set_xlabel('Energy (eV)')
3618        else:
3619            fig = ax_list[0].get_figure()
3620
3621        for spin in range(self.nsppol):
3622            opts = {"color": "black", "linewidth": 1.0} if spin == 0 else \
3623                   {"color": "red", "linewidth": 1.0}
3624            # Plot Total dos if unpolarized.
3625            if self.nsppol == 1: spin = None
3626            self.plot_ax(ax_list[0], e0, spin=spin, what="idos", **opts)
3627            self.plot_ax(ax_list[1], e0, spin=spin, what="dos", **opts)
3628
3629        return fig
3630
3631    @add_fig_kwargs
3632    def plot_up_minus_down(self, e0="fermie", ax=None, xlims=None, **kwargs):
3633        """
3634        Plot Dos_up - Dow_down
3635
3636        Args:
3637            e0: Option used to define the zero of energy in the band structure plot. Possible values:
3638                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
3639                -  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
3640                -  None: Don't shift energies, equivalent to ``e0 = 0``
3641            ax: |matplotlib-Axes| or None if a new figure should be created.
3642            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
3643                   or scalar e.g. ``left``. If left (right) is None, default values are used
3644            kwargs: options passed to ``ax.plot``.
3645
3646        Return: |matplotlib-Figure|
3647        """
3648        dos_diff = self.up_minus_down
3649        idos_diff = dos_diff.integral()
3650
3651        e0 = self.get_e0(e0)
3652        if not kwargs:
3653            kwargs = {"color": "black", "linewidth": 1.0}
3654
3655        ax, fig, plt = get_ax_fig_plt(ax=ax)
3656        ax.plot(dos_diff.mesh - e0, dos_diff.values, **kwargs)
3657        ax.plot(idos_diff.mesh - e0, idos_diff.values, **kwargs)
3658
3659        ax.grid(True)
3660        set_axlims(ax, xlims, "x")
3661        ax.set_ylabel('Dos_up - Dos_down (states/eV)')
3662        ax.set_xlabel('Energy (eV)')
3663
3664        return fig
3665
3666
class ElectronDosPlotter(NotebookWriter):
    """
    Class for plotting multiple electronic DOSes.

    Usage example:

    .. code-block:: python

        plotter = ElectronDosPlotter()
        plotter.add_edos("foo dos", "foo.nc")
        plotter.add_edos("bar dos", "bar.nc")
        fig = plotter.gridplot()
    """
    # TODO: down-up option animate?

    def __init__(self, key_edos=None, edos_kwargs=None):
        """
        Args:
            key_edos: List of (label, edos) tuples. Each edos can be any object accepted by
                ``ElectronDos.as_edos``. Labels must be unique.
            edos_kwargs: optional dictionary with the options passed to ``get_edos`` to compute the electron DOS.
                Used only for the items that are not already ElectronDos instances.

        Raises:
            ValueError: if the same label appears more than once.
        """
        self.edoses_dict = OrderedDict()
        if key_edos is None: key_edos = []
        # Go through add_edos so that duplicated labels are reported instead of being silently overwritten.
        for label, edos in key_edos:
            self.add_edos(label, edos, edos_kwargs=edos_kwargs)

    def __len__(self):
        return len(self.edoses_dict)

    @property
    def edos_list(self):
        """List of DOSes"""
        return list(self.edoses_dict.values())

    def add_edos(self, label, edos, edos_kwargs=None):
        """
        Adds a DOS for plotting.

        Args:
            label: label for the DOS. Must be unique.
            edos: |ElectronDos| object.
            edos_kwargs: optional dictionary with the options passed to ``get_edos`` to compute the electron DOS.
                Used only if ``edos`` is not an ElectronDos instance.

        Raises:
            ValueError: if ``label`` is already registered.
        """
        if label in self.edoses_dict:
            raise ValueError("label %s is already in %s" % (label, list(self.edoses_dict.keys())))
        self.edoses_dict[label] = ElectronDos.as_edos(edos, edos_kwargs)

    @add_fig_kwargs
    def combiplot(self, what_list="dos", spin_mode="automatic", e0="fermie",
                  ax_list=None, xlims=None, fontsize=8, **kwargs):
        """
        Plot the DOSes on the same figure. Use ``gridplot`` to plot DOSes on different figures.

        Args:
            what_list: Selects quantities to plot e.g. ["dos", "idos"] to plot DOS and integrated DOS.
                "dos" for DOS only and "idos" for IDOS only
            spin_mode: "total" for total (I)DOS, "resolved" for plotting individual contributions.
                Meaningful only if nsppol == 2.
                "automatic" to use "resolved" if at least one DOS is polarized.
            e0: Option used to define the zero of energy in the band structure plot. Possible values:
                - ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (``self.fermie``).
                -  Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
                -  None: Don't shift energies, equivalent to ``e0 = 0``
            ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            fontsize (int): fontsize for titles and legend

        Return: |matplotlib-Figure|
        """
        if spin_mode == "automatic":
            spin_mode = "resolved" if any(edos.nsppol == 2 for edos in self.edoses_dict.values()) else "total"

        what_list = list_strings(what_list)
        nrows, ncols = len(what_list), 1
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        can_use_basename = self._can_use_basenames_as_labels()
        for i, (what, ax) in enumerate(zip(what_list, ax_list)):
            for label, edos in self.edoses_dict.items():
                if can_use_basename:
                    label = os.path.basename(label)
                else:
                    # Use relative paths if label is a file.
                    if os.path.isfile(label): label = os.path.relpath(label)

                # Here I handle spin and spin_mode.
                if edos.nsppol == 1 or spin_mode == "total":
                    # Plot total values
                    edos.plot_ax(ax, e0, what=what, spin=None, label=label)

                elif spin_mode == "resolved":
                    # Plot spin resolved quantities with sign.
                    # Note get_color to have same color for both spins.
                    lines = None
                    for spin in range(edos.nsppol):
                        fact = 1 if spin == 0 else -1
                        lines = edos.plot_ax(ax, e0, what=what, spin=spin, fact=fact,
                            color=None if spin == 0 else lines[0].get_color(),
                            label=label if spin == 0 else None)
                else:
                    raise ValueError("Wrong value for spin_mode: `%s`:" % str(spin_mode))

            ax.grid(True)
            if i == len(what_list) - 1:
                ax.set_xlabel("Energy (eV)")
            ax.set_ylabel('DOS (states/eV)' if what == "dos" else "IDOS")
            set_axlims(ax, xlims, "x")
            ax.legend(loc="best", shadow=True, fontsize=fontsize)

        return fig

    # An alias for combiplot.
    plot = combiplot

    @add_fig_kwargs
    def gridplot(self, what="dos", spin_mode="automatic", e0="fermie",
                 sharex=True, sharey=True, xlims=None, fontsize=8, **kwargs):
        """
        Plot multiple DOSes on a grid.

        Args:
            what: "dos" to plot DOS, "idos" for integrated DOS.
            spin_mode: "total" for total (I)DOS, "resolved" for plotting individual contributions.
                Meaningful only if nsppol == 2.
                "automatic" to use "resolved" if at least one DOS is polarized.
            e0: Option used to define the zero of energy in each subplot. Possible values::

                - ``fermie``: shift all energies to have zero energy at the Fermi energy of each DOS.
                -  Number (or string representing a number) e.g ``e0 = 0.5``: shift all energies
                   to have zero energy at 0.5 eV
                -  None: Don't shift energies, equivalent to ``e0 = 0``.

            sharex, sharey: True if x (y) axis should be shared.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            fontsize: Label and title fontsize.

        Return: |matplotlib-Figure|
        """
        if spin_mode == "automatic":
            spin_mode = "resolved" if any(edos.nsppol == 2 for edos in self.edoses_dict.values()) else "total"

        edos_list = self.edos_list

        nrows, ncols = 1, 1
        numeb = len(edos_list)
        if numeb > 1:
            ncols = 2
            nrows = numeb // ncols + numeb % ncols

        # Build Grid
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=sharex, sharey=sharey, squeeze=False)
        ax_list = ax_list.ravel()

        # don't show the last ax if numeb is odd.
        if numeb % ncols != 0: ax_list[-1].axis("off")

        for i, ((label, edos), ax) in enumerate(zip(self.edoses_dict.items(), ax_list)):
            irow, icol = divmod(i, ncols)

            # Here I handle spin and spin_mode.
            if edos.nsppol == 1 or spin_mode == "total":
                opts = {"color": "black", "linewidth": 1.0}
                edos.plot_ax(ax, e0=e0, what=what, spin=None, **opts)

            elif spin_mode == "resolved":
                # Plot spin resolved quantities with sign.
                # Note get_color to have same color for both spins.
                lines = None
                for spin in range(edos.nsppol):
                    fact = 1 if spin == 0 else -1
                    lines = edos.plot_ax(ax, e0, what=what, spin=spin, fact=fact,
                        color=None if spin == 0 else lines[0].get_color(),
                        label=label if spin == 0 else None)
            else:
                raise ValueError("Wrong value for spin_mode: `%s`:" % str(spin_mode))

            ax.grid(True)
            ax.set_title(label, fontsize=fontsize)
            set_axlims(ax, xlims, "x")
            if (irow, icol) == (0, 0):
                ax.set_ylabel('DOS (states/eV)' if what == "dos" else "IDOS")
            if irow == nrows - 1:
                ax.set_xlabel("Energy (eV)")

        return fig

    def ipw_select_plot(self): # pragma: no cover
        """
        Return an ipython widget with controllers to select the plot.
        """
        def plot_callback(plot_type, e0):
            getattr(self, plot_type)(e0=e0, show=True)

        import ipywidgets as ipw
        return ipw.interact_manual(
                plot_callback,
                plot_type=["combiplot", "gridplot"],
                e0=["fermie", "0.0"],
            )

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        yield self.combiplot(show=False)
        yield self.gridplot(show=False)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to nbpath. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        # Use pickle files for data persistence.
        tmpfile = self.pickle_dump()

        nb.cells.extend([
            nbv.new_markdown_cell("# This is a markdown cell"),
            nbv.new_code_cell("plotter = abilab.ElectronDosPlotter.pickle_load('%s')" % tmpfile),
            nbv.new_code_cell("print(plotter)"),
            nbv.new_code_cell("xlims = (None, None)"),
            nbv.new_code_cell("plotter.combiplot(xlims=xlims);"),
            nbv.new_code_cell("plotter.gridplot(xlims=xlims);"),
        ])

        return self._write_nb_nbpath(nb, nbpath)

    def _can_use_basenames_as_labels(self):
        """
        Return True if all labels represent valid files and the basenames are unique
        In this case one can use the file basename instead of the full path in the plots.
        """
        if not all(os.path.exists(label) for label in self.edoses_dict): return False
        basenames = [os.path.basename(label) for label in self.edoses_dict]
        return len(set(basenames)) == len(basenames)
3912
3913
3914class Bands3D(Has_Structure):
3915
3916    def __init__(self, structure, ibz, has_timrev, eigens, fermie):
3917        """
3918        This object reconstructs by symmetry the eigenvalues in the full BZ starting from the IBZ.
3919        Provides methods to extract and visualize isosurfaces.
3920
3921        Args:
3922            structure:
3923            ibz:
3924            has_timrev:
3925            eigens:
3926            fermie
3927        """
3928        self.ibz = ibz
3929        self._structure = structure
3930        self.reciprocal_lattice = structure.lattice.reciprocal_lattice
3931        self.has_timrev = has_timrev
3932        self.fermie = fermie
3933        self.eigens = np.atleast_3d(eigens)
3934        self.nsppol, _, self.nband = self.eigens.shape
3935
3936        # Sanity check.
3937        errors = []; eapp = errors.append
3938        if not self.ibz.is_ibz:
3939            eapp("Expecting an IBZ sampling but got %s" % type(self.ibz))
3940        if not self.ibz.is_mpmesh:
3941            eapp("Monkhorst-Pack meshes are required.\nksampling: %s" % str(self.ibz.ksampling))
3942
3943        mpdivs, shifts = self.ibz.mpdivs_shifts
3944        if shifts is not None and not np.all(shifts == 0.0):
3945            eapp("Gamma-centered k-meshes are required by Xcrysden.")
3946        if errors:
3947            raise ValueError("\n".join(errors))
3948
3949        # Xcrysden requires points in the unit cell (C-order)
3950        # and the mesh must include the periodic images hence pbc=True.
3951        self.uc2ibz = map_grid2ibz(self.structure, self.ibz.frac_coords, mpdivs, self.has_timrev, pbc=True)
3952        self.mpdivs = mpdivs
3953        self.kdivs = mpdivs + 1
3954        self.spacing = 1.0 / mpdivs
3955        self.ucdata_shape = (self.nsppol, self.nband) + tuple(self.kdivs)
3956        #self.ibzdata_shape = (self.nsppol, self.nband, len(self.ibz))
3957
3958        # Construct energy bands on unit cell grid: e_{TSk} = e_{k}
3959        self.ucdata_sbk = self.symmetrize_ibz_scalars(self.eigens)
3960
3961        self.ucell_scalars = OrderedDict()
3962        self.ucell_vectors = OrderedDict()
3963        #if reference_sb is None:
3964        #self.reference_sb = [[] for _ in self.spins]
3965        #for spin in self.spins:
3966        #    self.reference_sb[spin] = {band: band for band in self.bands}
3967        #else:
3968
3969    @property
3970    def structure(self):
3971        """|Structure| object."""
3972        return self._structure
3973
3974    # Handy variables used to loop
3975    @property
3976    def spins(self):
3977        return range(self.nsppol)
3978
3979    @property
3980    def bands(self):
3981        return range(self.nband)
3982
3983    def __str__(self):
3984        return self.to_string()
3985
3986    def to_string(self, verbose=0):
3987        """String representation."""
3988        lines = []
3989        app = lines.append
3990        # TODO: Finalize implementation
3991        app(self.structure.to_string(verbose=verbose, title="Structure"))
3992        app("")
3993
3994        return "\n".join(lines)
3995
3996    def add_ucell_scalars(self, name, scalars):
3997        """
3998        Add scalar quantities given in the unit cell.
3999
4000        Args:
4001            name: keyword used to store scalars.
4002            scalars:
4003        """
4004        self.ucell_scalars[name] = np.reshape(scalars, self.ucdata_shape)
4005
4006    def add_ibz_scalars(self, name, scalars, inshape="skb"):
4007        """
4008        Add scalar quantities given in the IBZ i.e. symmetrize values to get array in unit cell.
4009
4010        Args:
4011            name: keyword used to store symmetrized values.
4012            scalars: scalars in IBZ. See ``inshape`` for shape
4013            inshape: shape of input scalars. "skb" if (nsppol, nkibz, nband)
4014            "sbk" for (nsppol, nband, nkibz).
4015        """
4016        self.add_ucell_scalars(name, self.symmetrize_ibz_scalars(scalars, inshape=inshape))
4017
4018    def symmetrize_ibz_scalars(self, scalars, inshape="skb"):
4019        """
4020        Symmetrize scalar quantities given in the IBZ.
4021
4022        Args:
4023            scalars: scalars in IBZ. See `inshape` for shape
4024            inshape: shape of input scalars. "skb" if (nsppol, nkibz, nband)
4025            "sbk" for (nsppol, nband, nkibz).
4026
4027        Return:
4028            |numpy-array| with scalars in unit cell. shape is **always**: (nsppol, nband, nkbz)
4029        """
4030        # Symmetrize scalars unit cell grid: e_{TSk} = e_{k}
4031        ucdata_sbk = np.empty((self.nsppol, self.nband, len(self.uc2ibz)))
4032
4033        if inshape == "skb":
4034            scalars = np.reshape(scalars, (self.nsppol, len(self.ibz), self.nband))
4035            for ikuc, ik_ibz in enumerate(self.uc2ibz):
4036                ucdata_sbk[:, :, ikuc] = scalars[:, ik_ibz, :]
4037        elif inshape == "sbk":
4038            scalars = np.reshape(scalars, (self.nsppol, self.nband, len(self.ibz)))
4039            for ikuc, ik_ibz in enumerate(self.uc2ibz):
4040                ucdata_sbk[:, :, ikuc] = scalars[:, :, ik_ibz]
4041        else:
4042            raise ValueError("Wrong inshape: %s" % str(inshape))
4043
4044        return ucdata_sbk
4045
4046    #def add_ucell_vectors(self, name, vectors, inshape="skb"):
4047    #    self.ucell_vectors[name] = np.reshape(vectors, self.ucdata + (3,))
4048
4049    #def add_ibz_vectors(self, name, scalars, inshape="skb")
4050    #    self.add_ucell_vectors(name, self.symmetrize_ibz_vectors(vectors, inshape=inshape))
4051
4052    #def wsmap(self):
4053        #ws = -np.ones(ngkpt, dtype=np.int)
4054        #for i in range(ngkpt[0]):
4055        #    ki = (i - ngkpt[0] // 2)
4056        #    if ki < 0: ki += ngkpt[0]
4057        #    for j in range(ngkpt[1]):
4058        #        kj = (j - ngkpt[1] // 2)
4059        #        if kj < 0: kj += ngkpt[1]
4060        #        for k in range(ngkpt[2]):
4061        #            kz = (k - ngkpt[2] // 2)
4062        #            if kz < 0: kz += ngkpt[2]
4063        #            #bzgrid2ibz[gp_bz[0], gp_bz[1], gp_bz[2]] = ik_ibz
4064        #            ws[i, j, k] = bzgrid2ibz[ki, kj, kz]
4065        #bzgrid2ibz = ws
4066
4067    def get_isobands(self, e0):
4068        """Return index of the bands crossing ``e0``in eV. None if no band is found."""
4069        isobands = [[] for _ in self.spins]
4070        for spin in self.spins:
4071            for band in self.bands:
4072                emin, emax = self.eigens[spin, :, band].min(), self.eigens[spin, :, band].max()
4073                if isobands[spin] and e0 > emax: break
4074                if emax >= e0 >= emin: isobands[spin].append(band)
4075        if all(not l for l in isobands): return None
4076        return isobands
4077
4078    def xcrysden_view(self):  # pragma: no cover
4079        """
4080        Visualize electron energy isosurfaces with xcrysden_.
4081        """
4082        _, tmp_filepath = tempfile.mkstemp(suffix=".bxsf", text=True)
4083        print("Producing BXSF file in:", tmp_filepath)
4084        self.to_bxsf(tmp_filepath, unit="eV")
4085        from abipy.iotools.visualizer import Xcrysden
4086        return Xcrysden(tmp_filepath)()
4087
4088    def to_bxsf(self, filepath, unit="eV"):
4089        """
4090        Export the full band structure to ``filepath`` in BXSF format
4091        suitable for the visualization of the Fermi surface with xcrysden_ (use ``xcrysden --bxsf FILE``).
4092        Require k-points in IBZ and gamma-centered k-mesh.
4093
4094        Args:
4095            filepath: BXSF filename or stream.
4096            unit: Input energies are in unit ``unit``.
4097        """
4098        from abipy.iotools import bxsf_write
4099        if hasattr(filepath, "write"):
4100            return bxsf_write(filepath, self.structure, self.nsppol, self.nband, self.kdivs,
4101                              self.ucdata_sbk, self.fermie, unit=unit)
4102        else:
4103            with open(filepath, "wt") as fh:
4104                bxsf_write(fh, self.structure, self.nsppol, self.nband, self.kdivs,
4105                           self.ucdata_sbk, self.fermie, unit=unit)
4106                return filepath
4107
4108    def get_e0(self, e0):
4109        """
4110        e0: Option used to define the zero of energy in the band structure plot. Possible values:
4111                - `fermie`: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
4112                -  Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
4113                -  None: Don't shift energies, equivalent to e0=0
4114        """
4115        if e0 is None:
4116            return 0.0
4117        elif is_string(e0):
4118            if e0 == "fermie":
4119                return self.fermie
4120            elif e0 == "None":
4121                return 0.0
4122            else:
4123                raise ValueError("Wrong value for e0: %s" % e0)
4124        else:
4125            # Assume number
4126            return e0
4127
4128    @add_fig_kwargs
4129    def plot_isosurfaces(self, e0="fermie", cmap=None, verbose=0, **kwargs):
4130        """
4131        Plot isosurface with matplotlib_
4132
4133        .. warning::
4134
4135            Requires scikit-image package, matplotlib rendering is usually slow.
4136
4137        Args:
4138            e0: Isolevel in eV. Default: Fermi energy.
4139            verbose: verbosity level.
4140
4141        Return: |matplotlib-Figure|
4142        """
4143        try:
4144            from skimage.measure import marching_cubes_lewiner as marching_cubes
4145        except ImportError:
4146            try:
4147                from skimage.measure import marching_cubes
4148            except ImportError:
4149                raise ImportError("scikit-image not installed.\n"
4150                    "Please install with it with `conda install scikit-image` or `pip install scikit-image`")
4151
4152        e0 = self.get_e0(e0)
4153        isobands = self.get_isobands(e0)
4154        if isobands is None: return None
4155        if verbose: print("Bands for isosurface:", isobands)
4156
4157        #from pymatgen.electronic_structure.plotter import plot_lattice_vectors, plot_wigner_seitz
4158        ax, fig, plt = get_ax3d_fig_plt(ax=None)
4159        plot_unit_cell(self.reciprocal_lattice, ax=ax, color="k", linewidth=1)
4160        #plot_wigner_seitz(self.reciprocal_lattice, ax=ax, color="k", linewidth=1)
4161
4162        for spin in self.spins:
4163            for ib, band in enumerate(isobands[spin]):
4164                # From http://scikit-image.org/docs/stable/api/skimage.measure.html#marching-cubes
4165                # verts: (V, 3) array
4166                #   Spatial coordinates for V unique mesh vertices. Coordinate order matches input volume (M, N, P).
4167                # faces: (F, 3) array
4168                #   Define triangular faces via referencing vertex indices from verts.
4169                #   This algorithm specifically outputs triangles, so each face has exactly three indices.
4170                # normals: (V, 3) array
4171                #   The normal direction at each vertex, as calculated from the data.
4172                # values: (V, ) array
4173                #   Gives a measure for the maximum value of the data in the local region near each vertex.
4174                #   This can be used by visualization tools to apply a colormap to the mesh
4175                voldata = np.reshape(self.ucdata_sbk[spin, band], self.kdivs)
4176                verts, faces, normals, values = marching_cubes(voldata, level=e0, spacing=tuple(self.spacing))
4177                #verts, faces, normals, values = marching_cubes_lewiner(voldata, level=e0, spacing=tuple(self.spacing))
4178                verts = self.reciprocal_lattice.get_cartesian_coords(verts)
4179
4180                if cmap is not None:
4181                    cmap = plt.get_cmap(cmap)
4182                    kwargs["color"] = cmap(float(ib) / len(isobands[spin]))
4183
4184                ax.plot_trisurf(verts[:, 0], verts[:, 1], faces, verts[:, 2], **kwargs)
4185                    #, cmap='Spectral', lw=1, antialiased=True)
4186
4187                # mayavi package:
4188                #mlab.triangular_mesh([v[0] for v in verts], [v[1] for v in verts], [v[2] for v in verts], faces)
4189                #, color=(0, 0, 0))
4190
4191        ax.set_axis_off()
4192
4193        return fig
4194
4195    def mvplot_isosurfaces(self, e0="fermie", verbose=0, figure=None, show=True):  # pragma: no cover
4196        """
4197        Plot isosurface with mayavi_
4198
4199        Args:
4200            e0:
4201            verbose:
4202            show:
4203        """
4204        # Find bands crossing e0.
4205        e0 = self.get_e0(e0)
4206        isobands = self.get_isobands(e0)
4207        if isobands is None: return None
4208        if verbose: print("Bands for isosurface:", isobands)
4209
4210        #from pymatgen.electronic_structure.plotter import plot_fermi_surface
4211        #spin, band = 0, 4
4212        #for i, band in enumerate(isobands[spin]):
4213        #data = np.reshape(isoenes[0][band], mpdivs + 1 if pbc else mpdivs)
4214        #plot_fermi_surface(data, self.structure, False, energy_levels=[e0]) interative=not (i == len(isobands[spin]) - 1))
4215
4216        # Plot isosurface with mayavi.
4217        from abipy.display import mvtk
4218        figure, mlab = mvtk.get_fig_mlab(figure=figure)
4219        mvtk.plot_unit_cell(self.reciprocal_lattice, figure=figure)
4220        mvtk.plot_wigner_seitz(self.reciprocal_lattice, figure=figure)
4221        cell = self.reciprocal_lattice.matrix
4222
4223        for spin in self.spins:
4224            for band in isobands[spin]:
4225                data = np.reshape(self.ucdata_sbk[spin, band], self.kdivs)
4226                cp = mlab.contour3d(data, contours=[e0], transparent=True,
4227                                    #colormap='hot', color=(0, 0, 1), opacity=1.0, figure=figure)
4228                                    colormap='Set3', opacity=0.9, figure=figure)
4229
4230                polydata = cp.actor.actors[0].mapper.input
4231                pts = np.array(polydata.points) #  - 1  # TODO this + mpdivs should be correct
4232                if verbose: print("shape:", pts.shape, pts)
4233                polydata.points = np.dot(pts, cell / np.array(data.shape)[:, np.newaxis])
4234                #polydata.points = np.dot(pts, cell / np.array(self.mpdivs)[:, np.newaxis])
4235                mlab.view(distance="auto", figure=figure)
4236
4237        # Add k-point labels.
4238        labels = {k.name: k.frac_coords for k in self.structure.hsym_kpoints}
4239        mvtk.plot_labels(labels, lattice=self.structure.reciprocal_lattice, figure=figure)
4240
4241        if show: mlab.show()
4242        return figure
4243
4244    @add_fig_kwargs
4245    def plot_contour(self, band, spin=0, plane="xy", elevation=0, ax=None, fontsize=8, **kwargs):
4246        """
4247        Contour plot with matplotlib_.
4248
4249        Args:
4250            band: Band index
4251            spin: Spin index.
4252            plane:
4253            elevation:
4254            ax: |matplotlib-Axes| or None if a new figure should be created.
4255            fontsize: Label and title fontsize.
4256
4257        Return: |matplotlib-Figure|
4258        """
4259        data = np.reshape(self.ucdata_sbk[spin, band], self.kdivs) - self.fermie
4260
4261        x = np.arange(self.kdivs[0]) / (self.kdivs[0] - 1)
4262        y = np.arange(self.kdivs[1]) / (self.kdivs[1] - 1)
4263        fxy = data[:, :, elevation]
4264
4265        ax, fig, plt = get_ax_fig_plt(ax=ax)
4266        x, y = np.meshgrid(x, y)
4267        c = ax.contour(x, y, fxy, **kwargs)
4268        ax.clabel(c, inline=1, fontsize=fontsize)
4269        kvert = dict(xy="z", xz="y", yz="x")[plane]
4270        ax.set_title(r"Band %s in %s plane at $K_{%s}=%d$" % (band, plane, kvert, elevation), fontsize=fontsize)
4271        ax.grid(True)
4272        ax.set_xlabel("$K_%s$" % plane[0])
4273        ax.set_ylabel("$K_%s$" % plane[1])
4274
4275        return fig
4276
4277    #def interpolate(self, densify_mpdivs):
4278        #densify_mpdivs = np.array(densify_mpdivs)
4279        #if np.any(densify_mpdivs > 1):
4280        #dense_mpdivs = densify_mpdivs * mpdivs
4281        #dense_kpts = kmesh_from_mpdivs(dense_mpdivs, shifts=(0, 0, 0), pbc=pbc, order="unit_cell")
4282        #from scipy.interpolate import RegularGridInterpolator
4283        #x = np.arange(0, mpdivs[0] + 1) / mpdivs[0]
4284        #y = np.arange(0, mpdivs[1] + 1) / mpdivs[1]
4285        #z = np.arange(0, mpdivs[2] + 1) / mpdivs[2]
4286        #for spin in self.spins:
4287        #    for band in isobands[spin]:
4288        #        interp = RegularGridInterpolator((x, y, z), isoenes[spin][band], method='linear')
4289        #        isoenes[spin][band] = interp(dense_mpdivs)
4290
4291    #def mvplot_surf(self):
4292        #spin, band = 0, 2
4293        #data = np.reshape(isoenes[spin][band], mpdivs + 1 if pbc else mpdivs)
4294        #x, y = np.arange(data.shape[0]), np.arange(data.shape[1])
4295        #cp = mlab.surf(x, y, data[0,:,:], figure=figure)
4296        #polydata = cp.actor.actors[0].mapper.input
4297        #pts = np.array(polydata.points) # - 1
4298        #xs, ys, zs = pts.T
4299        #print(pts.shape)
4300        #print(zs)
4301        #polydata.points = np.dot(pts, cell / np.array(data.shape)[:, np.newaxis])
4302
4303        #data = np.reshape(isoenes[spin][band+1], mpdivs + 1 if pbc else mpdivs)
4304        #cp = mlab.surf(x, y, data[0,:,:], figure=figure)
4305        #polydata = cp.actor.actors[0].mapper.input
4306        #pts = np.array(polydata.points) # - 1
4307        #polydata.points = np.dot(pts, cell / np.array(data.shape)[:, np.newaxis])
4308        #mlab.view(distance="auto", figure=figure)
4309        #if show: mlab.show()
4310        #return
4311
4312    def mvplot_cutplanes(self, band, spin=0, figure=None, show=True, **kwargs): # pragma: no cover
4313        """Plot cutplanes with mayavi_."""
4314        data = np.reshape(self.ucdata_sbk[spin, band], self.kdivs) - self.fermie
4315        contours = [-1.0, 0.0, 1.0]
4316
4317        from abipy.display import mvtk
4318        figure, mlab = mvtk.get_fig_mlab(figure=figure)
4319        src = mlab.pipeline.scalar_field(data)
4320
4321        mlab.pipeline.image_plane_widget(src, plane_orientation='x_axes', slice_index=self.kdivs[0]//2)
4322        mlab.pipeline.image_plane_widget(src, plane_orientation='y_axes', slice_index=self.kdivs[1]//2)
4323        mlab.pipeline.image_plane_widget(src, plane_orientation='z_axes', slice_index=self.kdivs[2]//2)
4324        mlab.pipeline.iso_surface(src, contours=contours) #, opacity=0.1)
4325        #mlab.pipeline.iso_surface(src, contours=[data.min()+ 0.1 * data.ptp()], opacity=0.1)
4326        mlab.outline()
4327
4328        if show: mlab.show()
4329        return figure
4330
4331    #def write_data(self, workdir, fmt="cube", rmdir=False)
4332
4333
class ElectronBands3D(Bands3D):
    """
    Subclass of Bands3D intended for electronic bands.

    At present it adds nothing to the parent class.
    """
    # TODO: add make_fermisurfer_dir(self, workdir)
4337
4338#class PhononBands3D(Bands3D):
4339#    pass
4340
4341
class RobotWithEbands(object):
    """
    Mixin class for robots associated to files with |ElectronBands|.

    The host class must provide ``items``, ``abifiles``, ``sortby``, ``group_and_sortby``,
    ``_get_label``, ``iter_lineopt`` and ``get_nbformat_nbv``.
    Each abifile must expose ``ebands``, ``nsppol`` and ``filepath``.
    """
    def combiplot_ebands(self, **kwargs):
        """Wraps combiplot method of |ElectronBandsPlotter|. kwargs passed to combiplot."""
        return self.get_ebands_plotter().combiplot(**kwargs)

    def gridplot_ebands(self, **kwargs):
        """Wraps gridplot method of |ElectronBandsPlotter|. kwargs passed to gridplot."""
        return self.get_ebands_plotter().gridplot(**kwargs)

    def boxplot_ebands(self, **kwargs):
        """Wraps boxplot method of |ElectronBandsPlotter|. kwargs passed to boxplot."""
        return self.get_ebands_plotter().boxplot(**kwargs)

    def combiboxplot_ebands(self, **kwargs):
        """Wraps combiboxplot method of |ElectronBandsPlotter|. kwargs passed to combiboxplot."""
        return self.get_ebands_plotter().combiboxplot(**kwargs)

    def combiplot_edos(self, **kwargs):
        """Wraps combiplot method of |ElectronDosPlotter|. kwargs passed to combiplot."""
        return self.get_edos_plotter().combiplot(**kwargs)

    def gridplot_edos(self, **kwargs):
        """Wraps gridplot method of |ElectronDosPlotter|. kwargs passed to gridplot."""
        return self.get_edos_plotter().gridplot(**kwargs)

    def get_ebands_plotter(self, kselect=None, filter_abifile=None, cls=None):
        """
        Build and return an instance of |ElectronBandsPlotter| or a subclass if ``cls`` is not None.

        Args:
            kselect (str): Used to select particular `ebands`.
                "path" to select bands given on a k-path, "ibz" for bands with IBZ sampling.
                Any other value (including None) does not filter.
            filter_abifile: Function that receives an ``abifile`` object and returns
                True if the file should be added to the plotter.
            cls: subclass of |ElectronBandsPlotter|.
        """
        plotter = ElectronBandsPlotter() if cls is None else cls()

        for label, abifile in self.items():
            if filter_abifile is not None and not filter_abifile(abifile): continue
            if kselect is not None:
                if kselect == "path" and not abifile.ebands.kpoints.is_path: continue
                if kselect == "ibz" and not abifile.ebands.kpoints.is_ibz: continue
            plotter.add_ebands(label, abifile.ebands)

        return plotter

    def get_edos_plotter(self, cls=None, filter_abifile=None, **kwargs):
        """
        Build and return an instance of |ElectronDosPlotter| or a subclass if ``cls`` is not None.

        Files whose k-point sampling is not an IBZ are skipped with a warning.

        Args:
            filter_abifile: Function that receives an ``abifile`` object and returns
                True if the file should be added to the plotter.
            cls: subclass of |ElectronDosPlotter|.
            kwargs: Arguments passed to ebands.get_edos
        """
        plotter = ElectronDosPlotter() if cls is None else cls()

        for label, abifile in self.items():
            if filter_abifile is not None and not filter_abifile(abifile): continue
            if not abifile.ebands.kpoints.is_ibz:
                cprint("Skipping %s because kpoint sampling not IBZ" % abifile.filepath, "magenta")
                continue
            plotter.add_edos(label, abifile.ebands.get_edos(**kwargs))

        return plotter

    #def get_ebands_dataframe(self, with_spglib=True):
    #    return dataframe_from_ebands(self.ncfiles, index=list(self.keys()), with_spglib=with_spglib)

    @add_fig_kwargs
    def plot_egaps(self, sortby=None, hue=None, fontsize=6, **kwargs):
        """
        Plot the convergence of the fundamental gaps, direct gaps and bandwidths
        wrt to the ``sortby`` parameter. Values can optionally be grouped by ``hue``.

        Args:
            sortby: Define the convergence parameter, sort files and produce plot labels.
                Can be None, string or function. If None, no sorting is performed.
                If string and not empty it's assumed that the abifile has an attribute
                with the same name and `getattr` is invoked.
                If callable, the output of sortby(abifile) is used.
            hue: Variable that define subsets of the data, which will be drawn on separate lines.
                Accepts callable or string
                If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
                If callable, the output of hue(abifile) is used.
            fontsize: legend and label fontsize.

        Returns: |matplotlib-Figure| or None if the robot contains no files.
        """
        # Note: Handling nsppol > 1 and the case in which we have abifiles with different nsppol is a bit tricky
        # hence we have to handle the different cases explicitly (see get_xy)
        if not self.abifiles: return None
        max_nsppol = max(f.nsppol for f in self.abifiles)

        items = ["fundamental_gaps", "direct_gaps", "bandwidths"]

        def get_xy(item, spin, all_xvals, all_abifiles):
            """
            Extract (xvals, yvals) from all_abifiles for given (item, spin) and initial all_xvals.
            all_xvals[i] must refer to all_abifiles[i].
            Files with nsppol <= spin are skipped so that files with different nsppol can be mixed.
            """
            xvals, yvals = [], []

            for i, af in enumerate(all_abifiles):
                if spin > af.nsppol - 1: continue
                xvals.append(all_xvals[i])
                if callable(item):
                    yy = float(item(af.ebands))
                else:
                    yy = getattr(af.ebands, item)
                    if item in ("fundamental_gaps", "direct_gaps"):
                        yy = yy[spin].energy
                    else:
                        yy = yy[spin]

                yvals.append(yy)

            return xvals, yvals

        # Build grid plot.
        nrows, ncols = len(items), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Sort and group files if hue.
        if hue is None:
            # Use the sorted files returned by sortby so that params[i] refers to ncfiles[i].
            # (self.abifiles is not reordered by sortby.)
            _, ncfiles, params = self.sortby(sortby, unpack=True)
        else:
            groups = self.group_and_sortby(hue, sortby)

        marker_spin = {0: "^", 1: "v"}
        for i, (ax, item) in enumerate(zip(ax_list, items)):
            for spin in range(max_nsppol):
                if hue is None:
                    # Extract data.
                    xvals, yvals = get_xy(item, spin, params, ncfiles)
                    if not is_string(xvals[0]):
                        ax.plot(xvals, yvals, marker=marker_spin[spin], **kwargs)
                    else:
                        # Must handle list of strings in a different way.
                        xn = range(len(xvals))
                        ax.plot(xn, yvals, marker=marker_spin[spin], **kwargs)
                        ax.set_xticks(xn)
                        ax.set_xticklabels(xvals, fontsize=fontsize)
                else:
                    for g in groups:
                        # Extract data.
                        xvals, yvals = get_xy(item, spin, g.xvalues, g.abifiles)
                        label = "%s: %s" % (self._get_label(hue), g.hvalue)
                        ax.plot(xvals, yvals, label=label, marker=marker_spin[spin], **kwargs)

            ax.grid(True)
            ax.set_ylabel(self._get_label(item))
            if i == len(items) - 1:
                ax.set_xlabel("%s" % self._get_label(sortby))
                if sortby is None: rotate_ticklabels(ax, 15)
            if i == 0:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    def get_ebands_code_cells(self, title=None):
        """Return list of notebook cells."""
        nbformat, nbv = self.get_nbformat_nbv()
        title = "## Code to compare multiple ElectronBands objects" if title is None else str(title)
        # Try not pollute namespace with lots of variables.
        return [
            nbv.new_markdown_cell(title),
            nbv.new_code_cell("robot.get_ebands_plotter().ipw_select_plot();"),
            nbv.new_code_cell("robot.get_edos_plotter().ipw_select_plot();"),
            nbv.new_code_cell("#robot.plot_egaps(sortby=None, hue=None);"),
        ]

    @add_fig_kwargs
    def gridplot_with_hue(self, hue, ylims=None, fontsize=8, sharex=False, sharey=False, **kwargs):
        """
        Plot multiple electron bandstructures on a grid. Group bands by ``hue``.

        Example:

            robot.gridplot_with_hue("nkpt")

        Args:
            hue: Variable that define subsets of the electron bands, which will be drawn on separate plots.
                Accepts callable or string
                If string, it's assumed that `abifile` has an attribute with the same name and getattr is invoked.
                Dot notation is also supported e.g. hue="structure.formula" --> abifile.structure.formula
                If callable, the output of hue(abifile) is used.
            ylims: Set the data limits for the y-axis. Accept tuple e.g. `(left, right)`
                or scalar e.g. `left`. If left (right) is None, default values are used
            fontsize: legend and title fontsize.
            sharex, sharey: True if X and Y axes should be shared.

        Returns: |matplotlib-Figure|
        """
        # Group abifiles by hue.
        groups = self.group_and_sortby(hue, func_or_string=None)
        nrows, ncols = len(groups), 1

        # Grid with one subplot per group.
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=sharex, sharey=sharey, squeeze=False)
        ax_list = ax_list.ravel()
        e0 = "fermie"  # Each ebands is aligned with respect to its Fermi energy.

        for ax, grp in zip(ax_list, groups):
            ax.grid(True)
            ebands_list = [abifile.ebands for abifile in grp.abifiles]
            ax.set_title("%s = %s" % (self._get_label(hue), grp.hvalue), fontsize=fontsize)

            nkpt_list = [ebands.nkpt for ebands in ebands_list]
            if any(nk != nkpt_list[0] for nk in nkpt_list):
                cprint("WARNING: Bands have different number of k-points:\n%s" % str(nkpt_list), "yellow")

            for i, (ebands, lineopts) in enumerate(zip(ebands_list, self.iter_lineopt())):
                # Plot all branches with lineopts and set the label of the last line produced.
                ebands.plot_ax(ax, e0, **lineopts)
                ax.lines[-1].set_label("%s" % grp.labels[i])

                # Set ticks and labels
                # (NB: we do this only for the first ebands, in principle ebands
                # in the group could have different k-points but there's no need to be so strict here.)
                if i == 0:
                    ebands.decorate_ax(ax, klabels=None)

            # Set legends.
            ax.legend(loc='best', fontsize=fontsize, shadow=True)
            set_axlims(ax, ylims, "y")

        return fig
4579
4580
4581from abipy.core.mixins import TextFile #, AbinitNcFile, NotebookWriter
4582from abipy.abio.robots import Robot
4583
4584
4585def find_yaml_section_in_lines(lines, tag):
4586
4587    magic = f"--- !{tag}"
4588    in_doc, buf = False, []
4589
4590    for line in lines:
4591        if line.startswith("#"):
4592            for i, c in enumerate(line):
4593                if c != "#": break
4594            line = line[i:]
4595
4596        if line.startswith(magic):
4597            in_doc = True
4598            continue
4599
4600        if in_doc and line.startswith("..."):
4601            in_doc = False
4602            break
4603
4604        if in_doc:
4605            buf.append(line.strip())
4606
4607    if not buf:
4608        raise ValueError(f"Cannot fine Yaml tag: `{magic}`")
4609
4610    import ruamel.yaml as yaml
4611    return yaml.safe_load("\n".join(buf))
4612
4613
class EdosFile(TextFile):
    """
    This object provides an interface to the _EDOS file
    (electron DOS usually computed with the tetrahedron method).
    The EdosFile has an ElectronDos edos object.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: EdosFile
    """

    def __init__(self, filepath):
        """
        Parse the EDOS file and build the ``self.edos`` |ElectronDos| object.

        Args:
            filepath: Path to the _EDOS file.

        Raises:
            ValueError: if the EDOS_PARAMS YAML section cannot be found in the header or
                if the number of data columns is neither 5 (nsppol == 1) nor 7 (nsppol == 2).
        """
        super().__init__(filepath)

        # Fortran implementation (eV units). See edos_write in m_ebands.F90.
        #
        # write(unt,"(a)")"# Energy           DOS_TOT          IDOS_TOT         DOS[spin=UP]     IDOS[spin=UP] ..."
        # do iw=1,edos%nw
        #   write(unt,'(es17.8)',advance='no')(edos%mesh(iw) - efermi) * cfact
        #   do spin=0,edos%nsppol
        #     write(unt,'(2es17.8)',advance='no')max(edos%dos(iw,spin) / cfact, tol30), max(edos%idos(iw,spin), tol30)
        #   end do
        #   write(unt,*)
        # end do

        # Split the file into header (lines starting with "#") and numerical rows.
        header, data = [], []
        for line in self:
            if line.startswith("#"):
                header.append(line)
                continue
            line = line.strip()
            if line:
                data.append([float(v) for v in line.split()])

        self.header_string = "".join(header)
        self.edos_params = find_yaml_section_in_lines(header, "EDOS_PARAMS")
        nelect = float(self.edos_params["nelect"])

        # Columns: energy, total DOS, total IDOS, then one (DOS, IDOS) pair for each spin.
        data = np.array(data).T.copy()

        # Validate the number of columns before indexing so that an empty or malformed
        # data section gives a meaningful ValueError instead of an IndexError.
        if len(data) == 5:
            # Spin unpolarized case: a single (DOS, IDOS) pair after the totals.
            spin_dos = data[3]
            spin_idos = data[4]
        elif len(data) == 7:
            # Spin polarized case: (DOS, IDOS) pairs for spin up and spin down.
            spin_dos = data[[3, 5]]
            spin_idos = data[[4, 6]]
        else:
            raise ValueError("Don't know how to interpret %d columns in %s" % (len(data), filepath))

        mesh = data[0]
        self.edos = ElectronDos(mesh, spin_dos, nelect, fermie=None, spin_idos=spin_idos)

    def to_string(self, verbose=0):
        """String representation: the header of the file followed by the ElectronDos summary."""
        return "\n".join([self.header_string, self.edos.to_string(verbose=verbose)])

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        yield self.edos.plot(show=False)
        yield self.edos.plot_dos_idos(show=False)
        if self.edos.nsppol == 2:
            yield self.edos.plot_up_minus_down(show=False)
4688
4689
class EdosRobot(Robot):
    """
    Robot for the analysis of multiple _EDOS files.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: EdosRobot
    """
    # File extension associated with this robot (consumed by the Robot base class).
    EXT = "EDOS"

    # TODO: implement yield_figs.
4707