1# coding: utf-8
2import functools
3import numpy as np
4import itertools
5import pickle
6import os
7import json
8import warnings
9import abipy.core.abinit_units as abu
10
11from collections import OrderedDict
12from monty.string import is_string, list_strings, marquee
13from monty.collections import dict2namedtuple
14from monty.functools import lazy_property
15from monty.termcolor import cprint
16from pymatgen.core.units import eV_to_Ha, Energy
17from pymatgen.core.periodic_table import Element
18from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
19from pymatgen.phonon.dos import CompletePhononDos as PmgCompletePhononDos, PhononDos as PmgPhononDos
20from abipy.core.func1d import Function1D
21from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_PhononBands, NotebookWriter
22from abipy.core.kpoints import Kpoint, Kpath, KpointList, kmesh_from_mpdivs
23from abipy.core.structure import Structure
24from abipy.abio.robots import Robot
25from abipy.iotools import ETSF_Reader
26from abipy.tools import duck
27from abipy.tools.numtools import gaussian, sort_and_groupby
28from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, set_axlims, get_axarray_fig_plt, set_visible, set_ax_xylabels
29from .phtk import match_eigenvectors, get_dyn_mat_eigenvec, open_file_phononwebsite, NonAnalyticalPh
30
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "PhononBands", "PhononBandsPlotter",
    "PhbstFile",
    "PhononDos", "PhononDosPlotter",
    "PhdosReader", "PhdosFile",
]
40
41
@functools.total_ordering
class PhononMode(object):
    """
    A phonon mode has a q-point, a frequency, a cartesian displacement and a |Structure|.

    Modes are compared and ordered by frequency only, so ``sorted`` puts the lowest frequency first.
    Since ``__eq__`` is defined without ``__hash__``, instances are not hashable in Python 3.
    """

    __slots__ = [
        "qpoint",
        "freq",
        "displ_cart",  # Cartesian displacement.
        "structure",
    ]

    def __init__(self, qpoint, freq, displ_cart, structure):
        """
        Args:
            qpoint: qpoint in reduced coordinates, converted with ``Kpoint.as_kpoint``.
            freq: Phonon frequency in eV.
            displ_cart: Displacement (Cartesian coordinates in Angstrom).
            structure: |Structure| object.
        """
        self.qpoint = Kpoint.as_kpoint(qpoint, structure.reciprocal_lattice)
        self.freq = freq
        self.displ_cart = displ_cart
        self.structure = structure

    # Rich comparison support (ordering is based on the frequency).
    # Missing operators are automatically filled by total_ordering.
    # Objects without a ``freq`` attribute give NotImplemented so that Python can try the reflected
    # operation: ``==`` then falls back to identity and ``<`` raises TypeError (not AttributeError).
    def __eq__(self, other):
        if not hasattr(other, "freq"):
            return NotImplemented
        return self.freq == other.freq

    def __lt__(self, other):
        if not hasattr(other, "freq"):
            return NotImplemented
        return self.freq < other.freq

    def __str__(self):
        return self.to_string(with_displ=False)

    def to_string(self, with_displ=True, verbose=0):
        """
        String representation

        Args:
            verbose: Verbosity level.
            with_displ: True to print phonon displacement.
        """
        lines = ["%s: q-point %s, frequency %.5f (eV)" % (self.__class__.__name__, self.qpoint, self.freq)]
        app = lines.append

        if with_displ:
            app("Phonon displacement in cartesian coordinates [Angstrom]")
            app(str(self.displ_cart))

        return "\n".join(lines)

    # TODO: possible extensions: a ``displ_red`` property (displacement in reduced coordinates),
    # export(path), visualize(visualizer), build_supercell().
103
104
105class PhononBands(object):
106    """
107    Container object storing the phonon band structure.
108
109    .. note::
110
111        Frequencies are in eV. Cartesian displacements are in Angstrom.
112    """
113    @classmethod
114    def from_file(cls, filepath):
115        """Create the object from a netcdf_ file."""
116        with PHBST_Reader(filepath) as r:
117            structure = r.read_structure()
118
119            # Build the list of q-points
120            qpoints = Kpath(structure.reciprocal_lattice, frac_coords=r.read_qredcoords(),
121                            weights=r.read_qweights(), names=None)
122
123            for qpoint in qpoints:
124                qpoint.set_name(structure.findname_in_hsym_stars(qpoint))
125
126            # Read amu
127            amu_list = r.read_amu()
128            if amu_list is not None:
129                atomic_numbers = r.read_value("atomic_numbers")
130                amu = {at: a for at, a in zip(atomic_numbers, amu_list)}
131            else:
132                cprint("Warning: file %s does not contain atomic_numbers.\nParticular methods need them!" %
133                       filepath, "red")
134                amu = None
135
136            non_anal_ph = None
137
138            # Reading NonAnalyticalPh here is not 100% safe as it may happen that the netcdf file
139            # does not contain all the directions required by AbiPy.
140            # So we read NonAnalyticalPh only if we know that all directions are available.
141            # The flag has_abipy_non_anal_ph is set at the Fortran level. See e.g ifc_mkphbs
142            if ("non_analytical_directions" in r.rootgrp.variables and
143                "has_abipy_non_anal_ph" in r.rootgrp.variables):
144                #print("Found non_anal_ph term compatible with AbiPy plotter.")
145                non_anal_ph = NonAnalyticalPh.from_file(filepath)
146
147            epsinf, zcart = r.read_epsinf_zcart()
148
149            return cls(structure=structure,
150                       qpoints=qpoints,
151                       phfreqs=r.read_phfreqs(),
152                       phdispl_cart=r.read_phdispl_cart(),
153                       amu=amu,
154                       non_anal_ph=non_anal_ph,
155                       epsinf=epsinf, zcart=zcart,
156                       )
157
158    @classmethod
159    def as_phbands(cls, obj):
160        """
161        Return an instance of |PhononBands| from a generic object ``obj``.
162        Supports:
163
164            - instances of cls
165            - files (string) that can be open with ``abiopen`` and that provide a ``phbands`` attribute.
166            - objects providing a ``phbands`` attribute.
167        """
168        if isinstance(obj, cls):
169            return obj
170
171        elif is_string(obj):
172            # path?
173            if obj.endswith(".pickle"):
174                with open(obj, "rb") as fh:
175                    return cls.as_phbands(pickle.load(fh))
176
177            from abipy.abilab import abiopen
178            with abiopen(obj) as abifile:
179                return abifile.phbands
180
181        elif hasattr(obj, "phbands"):
182            # object with phbands
183            return obj.phbands
184
185        raise TypeError("Don't know how to extract a PhononBands from type %s" % type(obj))
186
187    @staticmethod
188    def phfactor_ev2units(units):
189        """
190        Return conversion factor eV --> units (case-insensitive)
191        """
192        return abu.phfactor_ev2units(units)
193
194    def read_non_anal_from_file(self, filepath):
195        """
196        Reads the non analytical directions, frequencies and displacements from the anaddb.nc file
197        specified and adds them to the object.
198        """
199        self.non_anal_ph = NonAnalyticalPh.from_file(filepath)
200
201    def __init__(self, structure, qpoints, phfreqs, phdispl_cart, non_anal_ph=None, amu=None,
202                 epsinf=None, zcart=None, linewidths=None):
203        """
204        Args:
205            structure: |Structure| object.
206            qpoints: |KpointList| instance.
207            phfreqs: Phonon frequencies in eV.
208            phdispl_cart: [nqpt, 3*natom, 3*natom] array with displacement in Cartesian coordinates in Angstrom.
209                The last dimension stores the cartesian components.
210            non_anal_ph: :class:`NonAnalyticalPh` with information of the non analytical contribution
211                None if contribution is not present.
212            amu: dictionary that associates the atomic species present in the structure to the values of the atomic
213                mass units used for the calculation.
214            epsinf: [3,3] matrix with electronic dielectric tensor in Cartesian coordinates.
215                None if not avaiable.
216            zcart: [natom, 3, 3] matrix with Born effective charges in Cartesian coordinates.
217                None if not available.
218            linewidths: Array-like object with the linewidths (eV) stored as [q, num_modes]
219        """
220        self.structure = structure
221
222        # KpointList with the q-points
223        self.qpoints = qpoints
224        self.num_qpoints = len(self.qpoints)
225
226        # numpy array with phonon frequencies. Shape=(nqpt, 3*natom)
227        self.phfreqs = phfreqs
228
229        # phonon displacements in Cartesian coordinates.
230        # `ndarray` of shape (nqpt, 3*natom, 3*natom).
231        # The last dimension stores the cartesian components.
232        self.phdispl_cart = phdispl_cart
233
234        # Handy variables used to loop.
235        self.num_atoms = structure.num_sites
236        self.num_branches = 3 * self.num_atoms
237        self.branches = range(self.num_branches)
238
239        self.non_anal_ph = non_anal_ph
240        self.amu = amu
241        self.amu_symbol = None
242        if amu is not None:
243            self.amu_symbol = {}
244            for z, m in amu.items():
245                el = Element.from_Z(int(z))
246                self.amu_symbol[el.symbol] = m
247
248        self._linewidths = None
249        if linewidths is not None:
250            self._linewidths = np.reshape(linewidths, self.phfreqs.shape)
251
252        self.epsinf = epsinf
253        self.zcart = zcart
254
255        # Dictionary with metadata e.g. nkpt, tsmear ...
256        self.params = OrderedDict()
257
258    # TODO: Replace num_qpoints with nqpt, deprecate num_qpoints
259    @property
260    def nqpt(self):
261        """An alias for num_qpoints."""
262        return self.num_qpoints
263
264    def __repr__(self):
265        """String representation (short version)"""
266        return "<%s, nk=%d, %s, id=%s>" % (
267                self.__class__.__name__, self.num_qpoints, self.structure.formula, id(self))
268
269    def __str__(self):
270        return self.to_string()
271
272    def to_string(self, title=None, with_structure=True, with_qpoints=False, verbose=0):
273        """
274        Human-readable string with useful information such as structure, q-points, ...
275
276        Args:
277            with_structure: False if structural info should not be displayed.
278            with_qpoints: False if q-point info shoud not be displayed.
279            verbose: Verbosity level.
280        """
281        lines = []; app = lines.append
282        if title is not None: app(marquee(title, mark="="))
283
284        if with_structure:
285            app(self.structure.to_string(verbose=verbose, title="Structure"))
286            app("")
287
288        #app(marquee("Phonon Bands", mark="="))
289        app("Number of q-points: %d" % self.num_qpoints)
290        app("Atomic mass units: %s" % str(self.amu))
291        has_dipdip = self.non_anal_ph is not None
292        app("Has non-analytical contribution for q --> 0: %s" % has_dipdip)
293        if verbose and has_dipdip:
294            app(str(self.non_anal_ph))
295
296        if with_qpoints:
297            app(self.qpoints.to_string(verbose=verbose, title="Q-points"))
298            app("")
299
300        return "\n".join(lines)
301
302    def __add__(self, other):
303        """self + other returns a |PhononBandsPlotter| object."""
304        if not isinstance(other, (PhononBands, PhononBandsPlotter)):
305            raise TypeError("Cannot add %s to %s" % (type(self), type(other)))
306
307        if isinstance(other, PhononBandsPlotter):
308            self_key = repr(self)
309            other.add_phbands(self_key, self)
310            return other
311        else:
312            plotter = PhononBandsPlotter()
313            self_key = repr(self)
314            plotter.add_phbands(self_key, self)
315            self_key = repr(self)
316            other_key = repr(other)
317            plotter.add_phbands(other_key, other)
318            return plotter
319
320    __radd__ = __add__
321
322    @lazy_property
323    def _auto_qlabels(self):
324        # Find the q-point names in the pymatgen database.
325        # We'll use _auto_qlabels to label the point in the matplotlib plot
326        # if qlabels are not specified by the user.
327        _auto_qlabels = OrderedDict()
328
329        # If the first or the last q-point are not recognized in findname_in_hsym_stars
330        # matplotlib won't show the full band structure along the k-path
331        # because the labels are not defined. Here we make sure that
332        # the labels for the extrema of the path are always defined.
333        _auto_qlabels[0] = " "
334
335        for idx, qpoint in enumerate(self.qpoints):
336            name = qpoint.name if qpoint.name is not None else self.structure.findname_in_hsym_stars(qpoint)
337            if name is not None:
338                _auto_qlabels[idx] = name
339                if qpoint.name is None: qpoint.set_name(name)
340
341        last = len(self.qpoints) - 1
342        if last not in _auto_qlabels: _auto_qlabels[last] = " "
343
344        return _auto_qlabels
345
346    @property
347    def displ_shape(self):
348        """The shape of phdispl_cart."""
349        return self.phdispl_cart.shape
350
351    @property
352    def minfreq(self):
353        """Minimum phonon frequency."""
354        return self.get_minfreq_mode()
355
356    @property
357    def maxfreq(self):
358        """Maximum phonon frequency in eV."""
359        return self.get_maxfreq_mode()
360
361    def get_minfreq_mode(self, mode=None):
362        """Compute the minimum of the frequencies."""
363        if mode is None:
364            return np.min(self.phfreqs)
365        else:
366            return np.min(self.phfreqs[:, mode])
367
368    def get_maxfreq_mode(self, mode=None):
369        """Compute the minimum of the frequencies."""
370        if mode is None:
371            return np.max(self.phfreqs)
372        else:
373            return np.max(self.phfreqs[:, mode])
374
375    @property
376    def shape(self):
377        """Shape of the array with the eigenvalues."""
378        return self.num_qpoints, self.num_branches
379
380    @property
381    def linewidths(self):
382        """linewidths in eV. |numpy-array| with shape [nqpt, num_branches]."""
383        return self._linewidths
384
385    @linewidths.setter
386    def linewidths(self, linewidths):
387        """Set the linewidths. Accept real array of shape [nqpt, num_branches] or None."""
388        if linewidths is not None:
389            linewidths = np.reshape(linewidths, self.shape)
390        self._linewidths = linewidths
391
392    @property
393    def has_linewidths(self):
394        """True if bands with linewidths."""
395        return getattr(self, "_linewidths", None) is not None
396
397    @lazy_property
398    def dyn_mat_eigenvect(self):
399        """
400        [nqpt, 3*natom, 3*natom] array with the orthonormal eigenvectors of the dynamical matrix.
401        in Cartesian coordinates.
402        """
403        return get_dyn_mat_eigenvec(self.phdispl_cart, self.structure, amu=self.amu)
404
405    @property
406    def non_anal_directions(self):
407        """Cartesian directions along which the non analytical frequencies and displacements are available"""
408        if self.non_anal_ph:
409            return self.non_anal_ph.directions
410        else:
411            return None
412
413    @property
414    def non_anal_phfreqs(self):
415        """Phonon frequencies with non analytical contribution in eV along non_anal_directions"""
416        if self.non_anal_ph:
417            return self.non_anal_ph.phfreqs
418        else:
419            return None
420
421    @property
422    def non_anal_phdispl_cart(self):
423        """Displacement in Cartesian coordinates with non analytical contribution along non_anal_directions"""
424        if self.non_anal_ph:
425            return self.non_anal_ph.phdispl_cart
426        else:
427            return None
428
429    @property
430    def non_anal_dyn_mat_eigenvect(self):
431        """Eigenvalues of the dynamical matrix with non analytical contribution along non_anal_directions."""
432        if self.non_anal_ph:
433            return self.non_anal_ph.dyn_mat_eigenvect
434        else:
435            return None
436
437    def to_xmgrace(self, filepath, units="meV"):
438        """
439        Write xmgrace_ file with phonon band structure energies and labels for high-symmetry q-points.
440
441        Args:
442            filepath: String with filename or stream.
443            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
444                Case-insensitive.
445        """
446        is_stream = hasattr(filepath, "write")
447
448        if is_stream:
449            f = filepath
450        else:
451            f = open(filepath, "wt")
452
453        def w(s):
454            f.write(s)
455            f.write("\n")
456
457        factor = abu.phfactor_ev2units(units)
458        wqnu_units = self.phfreqs * factor
459
460        import datetime
461        w("# Grace project file with phonon band energies.")
462        w("# Generated by AbiPy on: %s" % str(datetime.datetime.today()))
463        w("# Crystalline structure:")
464        for s in str(self.structure).splitlines():
465            w("# %s" % s)
466        w("# Energies are in %s." % units)
467        w("# List of q-points and their index (C notation i.e. count from 0)")
468        for iq, qpt in enumerate(self.qpoints):
469            w("# %d %s" % (iq, str(qpt.frac_coords)))
470        w("@page size 792, 612")
471        w("@page scroll 5%")
472        w("@page inout 5%")
473        w("@link page off")
474        w("@with g0")
475        w("@world xmin 0.00")
476        w('@world xmax %d' % (self.num_qpoints - 1))
477        w('@world ymin %s' % wqnu_units.min())
478        w('@world ymax %s' % wqnu_units.max())
479        w('@default linewidth 1.5')
480        w('@xaxis  tick on')
481        w('@xaxis  tick major 1')
482        w('@xaxis  tick major color 1')
483        w('@xaxis  tick major linestyle 3')
484        w('@xaxis  tick major grid on')
485        w('@xaxis  tick spec type both')
486        w('@xaxis  tick major 0, 0')
487
488        qticks, qlabels = self._make_ticks_and_labels(qlabels=None)
489        w('@xaxis  tick spec %d' % len(qticks))
490        for iq, (qtick, qlabel) in enumerate(zip(qticks, qlabels)):
491            w('@xaxis  tick major %d, %d' % (iq, qtick))
492            w('@xaxis  ticklabel %d, "%s"' % (iq, qlabel))
493
494        w('@xaxis  ticklabel char size 1.500000')
495        w('@yaxis  tick major 10')
496        w('@yaxis  label "Phonon %s"' % abu.phunit_tag(units))
497        w('@yaxis  label char size 1.500000')
498        w('@yaxis  ticklabel char size 1.500000')
499        for nu in self.branches:
500            w('@    s%d line color %d' % (nu, 1))
501
502        # TODO: support LO-TO splitting (?)
503        for nu in self.branches:
504            w('@target G0.S%d' % nu)
505            w('@type xy')
506            for iq in range(self.num_qpoints):
507                w('%d %.8E' % (iq, wqnu_units[iq, nu]))
508            w('&')
509
510        if not is_stream:
511            f.close()
512
513    # TODO
514    #def to_bxsf(self, filepath):
515    #    """
516    #    Export the full band structure to `filepath` in BXSF format
517    #    suitable for the visualization of isosurfaces with Xcrysden (xcrysden --bxsf FILE).
518    #    Require q-points in IBZ and gamma-centered q-mesh.
519    #    """
520    #    self.get_phbands3d().to_bxsf(filepath)
521
522    #def get_phbands3d(self):
523    #    has_timrev, fermie = True, 0.0
524    #    return PhononBands3D(self.structure, self.qpoints, has_timrev, self.phfreqs, fermie)
525
526    def qindex(self, qpoint):
527        """Returns the index of the qpoint. Accepts integer or reduced coordinates."""
528        if duck.is_intlike(qpoint):
529            return int(qpoint)
530        else:
531            return self.qpoints.index(qpoint)
532
533    def qindex_qpoint(self, qpoint, is_non_analytical_direction=False):
534        """
535        Returns (qindex, qpoint) from an integer or a qpoint.
536
537        Args:
538            qpoint: integer, vector of reduced coordinates or |Kpoint| object.
539            is_non_analytical_direction: True if qpoint should be interpreted as a fractional direction for q --> 0
540                In this case qindex refers to the index of the direction in the :class:`NonAnalyticalPh` object.
541        """
542        if not is_non_analytical_direction:
543            # Standard search in qpoints.
544            qindex = self.qindex(qpoint)
545            return qindex, self.qpoints[qindex]
546        else:
547            # Find index of direction given by qpoint.
548            if self.non_anal_ph is None:
549                raise ValueError("Phononbands does not contain non-analytical terms for q-->0")
550
551            # Extract direction (assumed in fractional coordinates)
552            if hasattr(qpoint, "frac_coords"):
553                direction = qpoint.frac_coords
554            elif duck.is_intlike(qpoint):
555                direction = self.non_anal_ph.directions[qpoint]
556            else:
557                direction = qpoint
558
559            qindex = self.non_anal_ph.index_direction(direction, cartesian=False)
560
561            # Convert to fractional coords.
562            cart_direc = self.non_anal_ph.directions[qindex]
563            red_direc = self.structure.reciprocal_lattice.get_fractional_coords(cart_direc)
564            qpoint = Kpoint(red_direc, self.structure.reciprocal_lattice, weight=None, name=None)
565
566            return qindex, qpoint
567
568    def get_unstable_modes(self, below_mev=-5.0):
569        """
570        Return a list of :class:`PhononMode` objects with the unstable modes.
571        A mode is unstable if its frequency is < below_mev. Output list is sorted
572        and modes with lowest frequency come first.
573        """
574        umodes = []
575
576        for iq, qpoint in enumerate(self.qpoints):
577            for nu in self.branches:
578                freq = self.phfreqs[iq, nu]
579                if freq < below_mev / 1000:
580                    displ_cart = self.phdispl_cart[iq, nu, :]
581                    umodes.append(PhononMode(qpoint, freq, displ_cart, self.structure))
582
583        return sorted(umodes)
584
585    # TODO
586    #def find_irreps(self, qpoint, tolerance):
587    #    """
588    #    Find the irreducible representation at this q-point
589    #    Raise: QIrrepsError if algorithm fails
590    #    """
591    #    qindex, qpoint = self.qindex_qpoint(qpoint)
592
593    def get_dict4pandas(self, with_spglib=True):
594        """
595        Return a :class:`OrderedDict` with the most important parameters:
596
597            - Chemical formula and number of atoms.
598            - Lattice lengths, angles and volume.
599            - The spacegroup number computed by Abinit (set to None if not available).
600            - The spacegroup number and symbol computed by spglib (set to None not `with_spglib`).
601
602        Useful to construct pandas DataFrames
603
604        Args:
605            with_spglib: If True, spglib_ is invoked to get the spacegroup symbol and number
606        """
607        odict = OrderedDict([
608            ("nqpt", self.num_qpoints), ("nmodes", self.num_branches),
609            ("min_freq", self.minfreq), ("max_freq", self.maxfreq),
610            ("mean_freq", self.phfreqs.mean()), ("std_freq", self.phfreqs.std())
611
612        ])
613        odict.update(self.structure.get_dict4pandas(with_spglib=with_spglib))
614
615        return odict
616
617    def get_phdos(self, method="gaussian", step=1.e-4, width=4.e-4):
618        """
619        Compute the phonon DOS on a linear mesh.
620
621        Args:
622            method: String defining the method
623            step: Energy step (eV) of the linear mesh.
624            width: Standard deviation (eV) of the gaussian.
625
626        Returns:
627            |PhononDos| object.
628
629        .. warning::
630
631            Requires a homogeneous sampling of the Brillouin zone.
632        """
633        if abs(self.qpoints.sum_weights() - 1) > 1.e-6:
634            raise ValueError("Qpoint weights should sum up to one")
635
636        # Compute the linear mesh for the DOS
637        w_min = self.minfreq
638        w_min -= 0.1 * abs(w_min)
639        w_max = self.maxfreq
640        w_max += 0.1 * abs(w_max)
641        nw = 1 + (w_max - w_min) / step
642
643        mesh, step = np.linspace(w_min, w_max, num=nw, endpoint=True, retstep=True)
644
645        values = np.zeros(nw)
646        if method == "gaussian":
647            for q, qpoint in enumerate(self.qpoints):
648                weight = qpoint.weight
649                for nu in self.branches:
650                    w = self.phfreqs[q, nu]
651                    values += weight * gaussian(mesh, width, center=w)
652
653        else:
654            raise ValueError("Method %s is not supported" % str(method))
655
656        return PhononDos(mesh, values)
657
658    def create_xyz_vib(self, iqpt, filename, pre_factor=200, do_real=True, scale_matrix=None, max_supercell=None):
659        """
660        Create vibration XYZ file for visualization of phonons.
661
662        Args:
663            iqpt: index of qpoint.
664            filename: name of the XYZ file that will be created.
665            pre_factor: Multiplication factor of the displacements.
666            do_real: True if we want only real part of the displacement, False means imaginary part.
667            scale_matrix: Scaling matrix of the supercell.
668            max_supercell: Maximum size of the supercell with respect to primitive cell.
669        """
670        if scale_matrix is None:
671            if max_supercell is None:
672                raise ValueError("If scale_matrix is None, max_supercell must be provided!")
673
674            scale_matrix = self.structure.get_smallest_supercell(self.qpoints[iqpt].frac_coords,
675                                                                 max_supercell=max_supercell)
676
677        natoms = int(np.round(len(self.structure) * np.linalg.det(scale_matrix)))
678
679        with open(filename, "wt") as xyz_file:
680            for imode in np.arange(self.num_branches):
681                xyz_file.write(str(natoms) + "\n")
682                xyz_file.write("Mode " + str(imode) + " : " + str(self.phfreqs[iqpt, imode]) + "\n")
683                self.structure.write_vib_file(
684                    xyz_file, self.qpoints[iqpt].frac_coords,
685                    pre_factor * np.reshape(self.phdispl_cart[iqpt, imode,:],(-1,3)),
686                    do_real=True, frac_coords=False, max_supercell=max_supercell, scale_matrix=scale_matrix)
687
688    def create_ascii_vib(self, iqpts, filename, pre_factor=1):
689        """
690        Create vibration ascii file for visualization of phonons.
691        This format can be read with v_sim_ or ascii-phonons.
692
693        Args:
694            iqpts: an index or a list of indices of the qpoints in self. Note that at present only V_sim supports
695                an ascii file with multiple qpoints.
696            filename: name of the ascii file that will be created.
697            pre_factor: Multiplication factor of the displacements.
698        """
699        if not isinstance(iqpts, (list, tuple)):
700            iqpts = [iqpts]
701
702        structure = self.structure
703        a, b, c = structure.lattice.abc
704        alpha, beta, gamma = (np.pi*a/180 for a in structure.lattice.angles)
705        m = structure.lattice.matrix
706        sign = np.sign(np.dot(np.cross(m[0], m[1]), m[2]))
707
708        dxx = a
709        dyx = b * np.cos(gamma)
710        dyy = b * np.sin(gamma)
711        dzx = c * np.cos(beta)
712        dzy = c * (np.cos(alpha) - np.cos(gamma) * np.cos(beta)) / np.sin(gamma)
713        # keep the same orientation
714        dzz = sign*np.sqrt(c**2-dzx**2-dzy**2)
715
716        lines = ["# ascii file generated with abipy"]
717        lines.append("  {: 3.10f}  {: 3.10f}  {: 3.10f}".format(dxx, dyx, dyy))
718        lines.append("  {: 3.10f}  {: 3.10f}  {: 3.10f}".format(dzx, dzy, dzz))
719
720        # use reduced coordinates
721        lines.append("#keyword: reduced")
722
723        # coordinates
724        for s in structure:
725            lines.append("  {: 3.10f}  {: 3.10f}  {: 3.10f} {:>2}".format(s.a, s.b, s.c, s.specie.name))
726
727        ascii_basis = [[dxx, 0, 0],
728                       [dyx, dyy, 0],
729                       [dzx, dzy, dzz]]
730
731        for iqpt in iqpts:
732            q = self.qpoints[iqpt].frac_coords
733
734            displ_list = np.zeros((self.num_branches, self.num_atoms, 3), dtype=complex)
735            for i in range(self.num_atoms):
736                displ_list[:,i,:] = self.phdispl_cart[iqpt,:,3*i:3*(i+1)] * \
737                    np.exp(-2*np.pi*1j*np.dot(structure[i].frac_coords, self.qpoints[iqpt].frac_coords))
738
739            displ_list = np.dot(np.dot(displ_list, structure.lattice.inv_matrix), ascii_basis) * pre_factor
740
741            for imode in np.arange(self.num_branches):
742                lines.append("#metaData: qpt=[{:.6f};{:.6f};{:.6f};{:.6f} \\".format(
743                    q[0], q[1], q[2], self.phfreqs[iqpt, imode]))
744
745                for displ in displ_list[imode]:
746                    line = "#; " + "; ".join("{:.6f}".format(i) for i in displ.real) + "; " \
747                           + "; ".join("{:.6f}".format(i) for i in displ.imag) + " \\"
748                    lines.append(line)
749
750                lines.append(("# ]"))
751
752        with open(filename, 'wt') as f:
753            f.write("\n".join(lines))
754
755    def view_phononwebsite(self, browser=None, verbose=0, dryrun=False, **kwargs):
756        """
757        Produce JSON_ file that can be parsed from the phononwebsite_ and open it in ``browser``.
758
759        Args:
760            browser: Open webpage in ``browser``. Use default $BROWSER if None.
761            verbose: Verbosity level
762            dryrun: Activate dryrun mode for unit testing purposes.
763            kwargs: Passed to create_phononwebsite_json method
764
765        Return: Exit status
766        """
767        # Create json in abipy_nbworkdir with relative path so that we can read it inside the browser.
768        from abipy.core.globals import abinb_mkstemp
769        prefix = self.structure.formula.replace(" ", "")
770        _, rpath = abinb_mkstemp(force_abinb_workdir=not dryrun, use_relpath=True,
771                                 prefix=prefix, suffix=".json", text=True)
772
773        if verbose: print("Writing json file:", rpath)
774        self.create_phononwebsite_json(rpath, indent=None, **kwargs)
775
776        if dryrun: return 0
777        return open_file_phononwebsite(rpath, browser=browser)
778
779    def create_phononwebsite_json(self, filename, name=None, repetitions=None, highsym_qpts=None,
780                                  match_bands=True, highsym_qpts_mode="std", indent=2):
781        """
782        Writes a JSON_ file that can be parsed from the phononwebsite_.
783
784        Args:
785            filename: name of the json file that will be created
786            name: name associated with the data.
787            repetitions: number of repetitions of the cell. List of three integers. Defaults to [3,3,3].
788            highsym_qpts: list of tuples. The first element of each tuple should be a list with the coordinates
789                of a high symmetry point, the second element of the tuple should be its label.
790            match_bands: if True tries to follow the band along the path based on the scalar product of the eigenvectors.
791            highsym_qpts_mode: if ``highsym_qpts`` is None, high symmetry q-points can be automatically determined.
792                Accepts the following values:
793                'split' will split the path based on points where the path changes direction in the Brillouin zone.
794                Similar to what is done in phononwebsite. Only Gamma will be labeled.
795                'std' uses the standard generation procedure for points and labels used in PhononBands.
796                None does not set any point.
797            indent: Indentation level, passed to json.dump
798        """
799
800        def split_non_collinear(qpts):
801            r"""
802            function that splits the list of qpoints at repetitions (only the first point will be considered as
803            high symm) and where the direction changes. Also sets :math:`\Gamma` for [0, 0, 0].
804            Similar to what is done in phononwebsite_.
805            """
806            h = []
807            if np.array_equal(qpts[0], [0, 0, 0]):
808                h.append((0, "\\Gamma"))
809            for i in range(1, len(qpts)-1):
810                if np.array_equal(qpts[i], [0,0,0]):
811                    h.append((i, "\\Gamma"))
812                elif np.array_equal(qpts[i], qpts[i+1]):
813                    h.append((i, ""))
814                else:
815                    v1 = [a_i - b_i for a_i, b_i in zip(qpts[i+1], qpts[i])]
816                    v2 = [a_i - b_i for a_i, b_i in zip(qpts[i-1], qpts[i])]
817                    if not np.isclose(np.linalg.det([v1,v2,[1,1,1]]), 0):
818                        h.append((i, ""))
819            if np.array_equal(qpts[-1], [0, 0, 0]):
820                h.append((len(qpts)-1, "\\Gamma"))
821
822            return h
823
824        def reasonable_repetitions(natoms):
825            if (natoms < 4): return (3, 3, 3)
826            if (4 < natoms < 50): return (2, 2, 2)
827            if (50 < natoms): return (1, 1, 1)
828
829        # http://henriquemiranda.github.io/phononwebsite/index.html
830        data = {}
831        data["name"] = name or self.structure.composition.reduced_formula
832        data["natoms"] = self.num_atoms
833        data["lattice"] = self.structure.lattice.matrix.tolist()
834        data["atom_types"] = [e.name for e in self.structure.species]
835        data["atom_numbers"] = self.structure.atomic_numbers
836        data["formula"] = self.structure.formula.replace(" ", "")
837        data["repetitions"] = repetitions or reasonable_repetitions(self.num_atoms)
838        data["atom_pos_car"] = self.structure.cart_coords.tolist()
839        data["atom_pos_red"] = self.structure.frac_coords.tolist()
840        data["chemical_symbols"] = self.structure.symbol_set
841        data["atomic_numbers"] = list(set(self.structure.atomic_numbers))
842
843        qpoints = []
844        for q_sublist in self.split_qpoints:
845            qpoints.extend(q_sublist.tolist())
846
847        if highsym_qpts is None:
848            if highsym_qpts_mode is None:
849                data["highsym_qpts"] = []
850            elif highsym_qpts_mode == 'split':
851                data["highsym_qpts"] = split_non_collinear(qpoints)
852            elif highsym_qpts_mode == 'std':
853                data["highsym_qpts"] = list(zip(*self._make_ticks_and_labels(None)))
854        else:
855            data["highsym_qpts"] = highsym_qpts
856
857        distances = [0]
858        for i in range(1, len(qpoints)):
859            q_coord_1 = self.structure.reciprocal_lattice.get_cartesian_coords(qpoints[i])
860            q_coord_2 = self.structure.reciprocal_lattice.get_cartesian_coords(qpoints[i-1])
861            distances.append(distances[-1] + np.linalg.norm(q_coord_1-q_coord_2))
862
863        eigenvalues = []
864        for i, phfreqs_sublist in enumerate(self.split_phfreqs):
865            phfreqs_sublist = phfreqs_sublist * eV_to_Ha * abu.Ha_cmm1
866            if match_bands:
867                ind = self.split_matched_indices[i]
868                phfreqs_sublist = phfreqs_sublist[np.arange(len(phfreqs_sublist))[:, None], ind]
869            eigenvalues.extend(phfreqs_sublist.tolist())
870
871        vectors = []
872
873        for i, (qpts, phdispl_sublist) in enumerate(zip(self.split_qpoints, self.split_phdispl_cart)):
874            vect = np.array(phdispl_sublist)
875
876            if match_bands:
877                vect = vect[np.arange(vect.shape[0])[:, None, None],
878                            self.split_matched_indices[i][...,None],
879                            np.arange(vect.shape[2])[None, None,:]]
880            v = vect.reshape((len(vect), self.num_branches,self.num_atoms, 3))
881            norm = [np.linalg.norm(vi) for vi in v[0,0]]
882            v /= max(norm)
883            v = np.stack([v.real, v.imag], axis=-1)
884
885            vectors.extend(v.tolist())
886
887        data["qpoints"] = qpoints
888        data["distances"] = distances
889        data["eigenvalues"] = eigenvalues
890        data["vectors"] = vectors
891        #print("name", data["name"], "\nhighsym_qpts:", data["highsym_qpts"])
892
893        with open(filename, 'wt') as json_file:
894            json.dump(data, json_file, indent=indent)
895
896    def make_isodistort_ph_dir(self, qpoint, select_modes=None, eta=1, workdir=None):
897        """
898        Compute ph-freqs for given q-point (default: Gamma),
899        produce CIF files for unperturbed and distorded structure
900        that can be used with ISODISTORT (https://stokes.byu.edu/iso/isodistort.php)
901        to analyze the symmetry of phonon modes.
902        See README.me file produced in output directory.
903
904        Args:
905            qpoint:
906            wordir:
907            select_modes:
908            eta: Amplitude of the displacement to be applied to the system. Will correspond to the
909                largest displacement of one atom in Angstrom.
910            scale_matrix: the scaling matrix of the supercell. If None a scaling matrix suitable for
911                the qpoint will be determined.
912            max_supercell: mandatory if scale_matrix is None, ignored otherwise. Defines the largest
913                supercell in the search for a scaling matrix suitable for the q point.
914        """
915        iq, qpoint = self.qindex_qpoint(qpoint)
916
917        scale_matrix = np.eye(3, 3, dtype=int)
918        important_fracs = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
919        for i in range(3):
920            for comparison_frac in important_fracs:
921                if abs(1 - qpoint.frac_coords[i] * comparison_frac) < 1e-4:
922                    scale_matrix[i, i] = comparison_frac
923                    break
924        print(f"Using scale_matrix:\n {scale_matrix}")
925
926        select_modes = self.branches if select_modes is None else select_modes
927        if workdir is None:
928            workdir = "%s_qpt%s" % (self.structure.formula, repr(qpoint))
929            workdir = workdir.replace(" ", "_").replace("$", "").replace("\\", "").replace("[", "").replace("]", "")
930
931        if os.path.isdir(workdir):
932            cprint(f"Removing pre-existing directory: {workdir}", "yellow")
933            import shutil
934            shutil.rmtree(workdir)
935
936        os.mkdir(workdir)
937
938        print(f"\nCreating CIF files for ISODISTORT code in {workdir}. See README.md")
939        self.structure.write_cif_with_spglib_symms(filename=os.path.join(workdir, "parent_structure.cif"))
940
941        for imode in select_modes:
942            # A namedtuple with a structure with the displaced atoms, a numpy array containing the
943            # displacements applied to each atom and the scale matrix used to generate the supercell.
944            r = self.get_frozen_phonons(qpoint, imode,
945                                        eta=eta, scale_matrix=scale_matrix, max_supercell=None)
946
947            print("after scale_matrix:", r.scale_matrix)
948            r.structure.write_cif_with_spglib_symms(filename=os.path.join(workdir,
949                                                    "distorted_structure_mode_%d.cif" % (imode + 1)))
950
951        readme_string = """
952
953Use Harold Stokes' code, [ISODISTORT](https://stokes.byu.edu/iso/isodistort.php),
954loading in your structure that you did the DFPT calculation as the **parent**,
955then, select mode decompositional analysis and upload the cif file from step (3).
956
957Follow the on screen instructions.
958You will then be presented with the mode irrep and other important symmetry information.
959
960Thanks to Jack Baker for pointing out this approach.
961See also <https://forum.abinit.org/viewtopic.php?f=10&t=545>
962"""
963        with open(os.path.join(workdir, "README.md"), "wt") as fh:
964            fh.write(readme_string)
965
966        return workdir
967
968    def decorate_ax(self, ax, units='eV', **kwargs):
969        """
970        Add q-labels, title and unit name to axis ax.
971        Use units="" to add k-labels without unit name.
972
973        Args:
974            title:
975            fontsize
976            qlabels:
977            qlabel_size:
978        """
979        title = kwargs.pop("title", None)
980        fontsize = kwargs.pop("fontsize", 12)
981        if title is not None: ax.set_title(title, fontsize=fontsize)
982        ax.grid(True)
983
984        # Handle conversion factor.
985        if units:
986            ax.set_ylabel(abu.wlabel_from_units(units))
987
988        ax.set_xlabel("Wave Vector")
989
990        # Set ticks and labels.
991        ticks, labels = self._make_ticks_and_labels(kwargs.pop("qlabels", None))
992        if ticks:
993            # Don't show label if previous k-point is the same.
994            for il in range(1, len(labels)):
995                if labels[il] == labels[il-1]: labels[il] = ""
996            ax.set_xticks(ticks, minor=False)
997            ax.set_xticklabels(labels, fontdict=None, minor=False, size=kwargs.pop("qlabel_size", "large"))
998            #print("ticks", len(ticks), ticks)
999            ax.set_xlim(ticks[0], ticks[-1])
1000
1001    @add_fig_kwargs
1002    def plot(self, ax=None, units="eV", qlabels=None, branch_range=None, match_bands=False, temp=None,
1003             fontsize=12, **kwargs):
1004        r"""
1005        Plot the phonon band structure.
1006
1007        Args:
1008            ax: |matplotlib-Axes| or None if a new figure should be created.
1009            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1010                Case-insensitive.
1011            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
1012                The values are the labels. e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
1013            branch_range: Tuple specifying the minimum and maximum branch index to plot (default: all branches are plotted).
1014            match_bands: if True the bands will be matched based on the scalar product between the eigenvectors.
1015            temp: Temperature in Kelvin. If not None, a scatter plot with the Bose-Einstein occupation factor
1016                at temperature `temp` is added.
1017            fontsize: Legend and title fontsize.
1018
1019        Returns: |matplotlib-Figure|
1020        """
1021        # Select the band range.
1022        branch_range = range(self.num_branches) if branch_range is None else \
1023                       range(branch_range[0], branch_range[1], 1)
1024
1025        ax, fig, plt = get_ax_fig_plt(ax=ax)
1026
1027        # Decorate the axis (e.g. add ticks and labels).
1028        self.decorate_ax(ax, units=units, qlabels=qlabels)
1029
1030        if "color" not in kwargs: kwargs["color"] = "black"
1031        if "linewidth" not in kwargs: kwargs["linewidth"] = 2.0
1032
1033        # Plot the phonon branches.
1034        self.plot_ax(ax, branch_range, units=units, match_bands=match_bands, **kwargs)
1035
1036        if temp is not None:
1037            # Scatter plot with Bose-Einstein occupation factors for T = temp
1038            factor = abu.phfactor_ev2units(units)
1039            if temp < 1: temp = 1
1040            ax.set_title("T = %.1f K" % temp, fontsize=fontsize)
1041            xs = np.arange(self.num_qpoints)
1042            for nu in self.branches:
1043                ws = self.phfreqs[:, nu]
1044                wkt = self.phfreqs[:, nu] / (abu.kb_eVK * temp)
1045                # 1 / (np.exp(1e-6) - 1)) ~ 999999.5
1046                wkt = np.where(wkt > 1e-6, wkt, 1e-6)
1047                occ = 1.0 / (np.exp(wkt) - 1.0)
1048                s = np.where(occ < 2, occ, 2) * 50
1049                ax.scatter(xs, ws * factor, s=s, marker="o", c="b", alpha=0.6)
1050                #ax.scatter(xs, ws, s=s, marker="o", c=occ, cmap="jet")
1051
1052        return fig
1053
1054    def plot_ax(self, ax, branch, units='eV', match_bands=False, **kwargs):
1055        """
1056        Plots the frequencies for the given branches indices as a function of the q-index on axis ``ax``.
1057        If ``branch`` is None, all phonon branches are plotted.
1058
1059        Return: The list of matplotlib lines added.
1060        """
1061        if branch is None:
1062            branch_range = range(self.num_branches)
1063        elif isinstance(branch, (list, tuple, np.ndarray)):
1064            branch_range = branch
1065        else:
1066            branch_range = [branch]
1067
1068        first_xx = 0
1069        lines = []
1070
1071        factor = abu.phfactor_ev2units(units)
1072
1073        for i, pf in enumerate(self.split_phfreqs):
1074            if match_bands:
1075                ind = self.split_matched_indices[i]
1076                pf = pf[np.arange(len(pf))[:, None], ind]
1077            pf = pf * factor
1078            xx = list(range(first_xx, first_xx + len(pf)))
1079            for branch in branch_range:
1080                lines.extend(ax.plot(xx, pf[:, branch], **kwargs))
1081            first_xx = xx[-1]
1082
1083        return lines
1084
1085    @add_fig_kwargs
1086    def plot_colored_matched(self, ax=None, units="eV", qlabels=None, branch_range=None,
1087                             colormap="rainbow", max_colors=None, **kwargs):
1088        r"""
1089        Plot the phonon band structure with different colors for each line.
1090
1091        Args:
1092            ax: |matplotlib-Axes| or None if a new figure should be created.
1093            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1094                Case-insensitive.
1095            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
1096                The values are the labels. e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
1097            branch_range: Tuple specifying the minimum and maximum branch_i index to plot
1098                (default: all branches are plotted).
1099            colormap: matplotlib colormap to determine the colors available. The colors will be chosen not in a
1100                sequential order to avoid difficulties in distinguishing the lines.
1101                http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html
1102            max_colors: maximum number of colors to be used. If max_colors < num_braches the colors will be reapeated.
1103                It useful to better distinguish close bands when the number of branch is large.
1104
1105        Returns: |matplotlib-Figure|
1106        """
1107        # Select the band range.
1108        if branch_range is None:
1109            branch_range = range(self.num_branches)
1110        else:
1111            branch_range = range(branch_range[0], branch_range[1], 1)
1112
1113        ax, fig, plt = get_ax_fig_plt(ax=ax)
1114
1115        # Decorate the axis (e.g add ticks and labels).
1116        self.decorate_ax(ax, units=units, qlabels=qlabels)
1117
1118        first_xx = 0
1119        lines = []
1120        factor = abu.phfactor_ev2units(units)
1121
1122        if max_colors is None:
1123            max_colors = len(branch_range)
1124
1125        colormap = plt.get_cmap(colormap)
1126
1127        for i, pf in enumerate(self.split_phfreqs):
1128            ind = self.split_matched_indices[i]
1129            pf = pf[np.arange(len(pf))[:, None], ind]
1130            pf = pf * factor
1131            xx = range(first_xx, first_xx + len(pf))
1132            colors = itertools.cycle(colormap(np.linspace(0, 1, max_colors)))
1133            for branch_i in branch_range:
1134                kwargs = dict(kwargs)
1135                kwargs['color'] = next(colors)
1136                lines.extend(ax.plot(xx, pf[:, branch_i], **kwargs))
1137            first_xx = xx[-1]
1138
1139        return fig
1140
1141    @add_fig_kwargs
1142    def plot_lt_character(self, units="eV", qlabels=None, ax=None, xlims=None, ylims=None,
1143                          colormap="jet", fontsize=12, **kwargs):
1144        r"""
1145        Plot the phonon band structure with colored lines. The color of the lines indicates
1146        the degree to which the mode is longitudinal:
1147        Red corresponds to longitudinal modes and black to purely transverse modes.
1148
1149        Args:
1150            ax: |matplotlib-Axes| or None if a new figure should be created.
1151            units: Units for plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1152                Case-insensitive.
1153            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
1154                The values are the labels. e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
1155            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
1156                   or scalar e.g. ``left``. If left (right) is None, default values are used.
1157            ylims: y-axis limits.
1158            colormap: Matplotlib colormap.
1159            fontsize: legend and title fontsize.
1160
1161        Returns: |matplotlib-Figure|
1162        """
1163        if self.zcart is None:
1164            cprint("Bandstructure does not have Born effective charges", "yellow")
1165            return None
1166
1167        factor = abu.phfactor_ev2units(units)
1168        ax, fig, plt = get_ax_fig_plt(ax=ax)
1169        cmap = plt.get_cmap(colormap)
1170
1171        if "color" not in kwargs: kwargs["color"] = "black"
1172        if "linewidth" not in kwargs: kwargs["linewidth"] = 2.0
1173
1174        first_xx = 0
1175        scatt_x, scatt_y, scatt_s = [], [], []
1176        for p_qpts, p_freqs, p_dcart in zip(self.split_qpoints, self.split_phfreqs, self.split_phdispl_cart):
1177            xx = list(range(first_xx, first_xx + len(p_freqs)))
1178
1179            for iq, (qpt, ws, dis) in enumerate(zip(p_qpts, p_freqs, p_dcart)):
1180                qcart = self.structure.reciprocal_lattice.get_cartesian_coords(qpt)
1181                qnorm = np.linalg.norm(qcart)
1182                inv_qepsq = 0.0
1183                if qnorm > 1e-3:
1184                    qvers = qcart / qnorm
1185                    inv_qepsq = 1.0 / np.dot(qvers, np.dot(self.epsinf, qvers))
1186
1187                # We are not interested in the amplitudes so normalize all displacements to one.
1188                dis = dis.reshape(self.num_branches, self.num_atoms, 3)
1189                # q x Z[atom] x disp[q, nu, atom]
1190                for nu in range(self.num_branches):
1191                    v = sum(np.dot(qcart, np.dot(self.zcart[iatom], dis[nu, iatom])) for iatom in range(self.num_atoms))
1192                    scatt_x.append(xx[iq])
1193                    scatt_y.append(ws[nu])
1194                    scatt_s.append(v * inv_qepsq)
1195
1196            p_freqs = p_freqs * factor
1197            ax.plot(xx, p_freqs, **kwargs)
1198            first_xx = xx[-1]
1199
1200        scatt_y = np.array(scatt_y) * factor
1201        scatt_s = np.abs(np.array(scatt_s))
1202        scatt_s /= scatt_s.max()
1203        scatt_s *= 50
1204        print("scatt_s", scatt_s, "min", scatt_s.min(), "max", scatt_s.max())
1205
1206        ax.scatter(scatt_x, scatt_y, s=scatt_s,
1207            #c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None,
1208            #linewidths=None, verts=None, edgecolors=None, *, data=None
1209        )
1210        self.decorate_ax(ax, units=units, qlabels=qlabels)
1211        set_axlims(ax, xlims, "x")
1212        set_axlims(ax, ylims, "y")
1213
1214        return fig
1215
1216    @property
1217    def split_qpoints(self):
1218        try:
1219            return self._split_qpoints
1220        except AttributeError:
1221            self._set_split_intervals()
1222            return self._split_qpoints
1223
1224    @property
1225    def split_phfreqs(self):
1226        try:
1227            return self._split_phfreqs
1228        except AttributeError:
1229            self._set_split_intervals()
1230            return self._split_phfreqs
1231
1232    @property
1233    def split_phdispl_cart(self):
1234        # prepare the splitted phdispl_cart as a separate internal variable only when explicitely requested and
1235        # not at the same time as split_qpoints and split_phfreqs as it requires a larger array and not used
1236        # most of the times.
1237        try:
1238            return self._split_phdispl_cart
1239        except AttributeError:
1240            self.split_phfreqs
1241            split_phdispl_cart = [np.array(self.phdispl_cart[self._split_indices[i]:self._split_indices[i + 1] + 1])
1242                                  for i in range(len(self._split_indices) - 1)]
1243            if self.non_anal_ph is not None:
1244                for i, q in enumerate(self.split_qpoints):
1245                    if np.array_equal(q[0], (0, 0, 0)):
1246                        if self.non_anal_ph.has_direction(q[1]):
1247                            split_phdispl_cart[i][0, :] = self._get_non_anal_phdispl(q[1])
1248                    if np.array_equal(q[-1], (0, 0, 0)):
1249                        if self.non_anal_ph.has_direction(q[-2]):
1250                            split_phdispl_cart[i][-1, :] = self._get_non_anal_phdispl(q[-2])
1251
1252            self._split_phdispl_cart = split_phdispl_cart
1253            return self._split_phdispl_cart
1254
1255    def _set_split_intervals(self):
1256        # Calculations available for LO-TO splitting
1257        # Split the lines at each Gamma to handle possible discontinuities
1258        if self.non_anal_phfreqs is not None and self.non_anal_directions is not None:
1259            end_points_indices = [0]
1260
1261            end_points_indices.extend(
1262                [i for i in range(1, self.num_qpoints - 1) if np.array_equal(self.qpoints.frac_coords[i], [0, 0, 0])])
1263            end_points_indices.append(self.num_qpoints - 1)
1264
1265            # split the list of qpoints and frequencies at each end point. The end points are in both the segments.
1266            # Lists since the array contained have different shapes
1267            split_qpoints = [np.array(self.qpoints.frac_coords[end_points_indices[i]:end_points_indices[i + 1] + 1])
1268                             for i in range(len(end_points_indices) - 1)]
1269            split_phfreqs = [np.array(self.phfreqs[end_points_indices[i]:end_points_indices[i + 1] + 1])
1270                             for i in range(len(end_points_indices) - 1)]
1271
1272            for i, q in enumerate(split_qpoints):
1273                if np.array_equal(q[0], (0, 0, 0)):
1274                    split_phfreqs[i][0, :] = self._get_non_anal_freqs(q[1])
1275                if np.array_equal(q[-1], (0, 0, 0)):
1276                    split_phfreqs[i][-1, :] = self._get_non_anal_freqs(q[-2])
1277        else:
1278            split_qpoints = [self.qpoints.frac_coords]
1279            split_phfreqs = [self.phfreqs]
1280            end_points_indices = [0, self.num_qpoints-1]
1281
1282        self._split_qpoints = split_qpoints
1283        self._split_phfreqs = split_phfreqs
1284        self._split_indices = end_points_indices
1285        return split_phfreqs, split_qpoints
1286
    @property
    def split_matched_indices(self):
        """
        A list of numpy arrays containing the indices in which each band should be sorted in order to match the
        scalar product of the eigenvectors. There is one integer array per segment of ``split_phdispl_cart``,
        with shape (number of q-points in the segment, num_branches), i.e. the same shape as the
        corresponding entry of split_phfreqs.
        Lazy property: computed on first access and cached in ``self._split_matched_indices``.

        NOTE(review): every segment must contain at least two q-points, since ``eigenvectors[1]`` and
        ``eigenvectors[-2]`` are accessed unconditionally.
        """
        try:
            return self._split_matched_indices
        except AttributeError:

            split_matched_indices = []
            # Eigenvectors at the second-to-last q-point of the previous segment (updated after each segment).
            last_eigenvectors = None

            # simpler method based just on the matching with the previous point
            #TODO remove after verifying the other method currently in use
            # for i, displ in enumerate(self.split_phdispl_cart):
            #     eigenvectors = get_dyn_mat_eigenvec(displ, self.structure, amu=self.amu)
            #     ind_block = np.zeros((len(displ), self.num_branches), dtype=int)
            #     # if it's not the first block, match with the last of the previous block. Should give a match in case
            #     # of LO-TO splitting
            #     if i == 0:
            #         ind_block[0] = range(self.num_branches)
            #     else:
            #         match = match_eigenvectors(last_eigenvectors, eigenvectors[0])
            #         ind_block[0] = [match[m] for m in split_matched_indices[-1][-1]]
            #     for j in range(1, len(displ)):
            #         match = match_eigenvectors(eigenvectors[j-1], eigenvectors[j])
            #         ind_block[j] = [match[m] for m in ind_block[j-1]]
            #
            #     split_matched_indices.append(ind_block)
            #     last_eigenvectors = eigenvectors[-1]

            # The match is applied between subsequent qpoints, except that right after a high symmetry point.
            # In that case the first point after the high symmetry point will be matched with the one immediately
            # before. This should avoid exchange of lines due to degeneracies.
            # The code will assume that there is a high symmetry point if the points are not collinear (change in the
            # direction in the path).
            def collinear(a, b, c):
                # True if det([b - a, c - a, [1, 1, 1]]) vanishes (atol=1e-5).
                # NOTE(review): a vanishing determinant is necessary for collinearity but not sufficient:
                # it also vanishes when b - a, c - a and [1, 1, 1] are coplanar.
                v1 = [b[0] - a[0], b[1] - a[1], b[2] - a[2]]
                v2 = [c[0] - a[0], c[1] - a[1], c[2] - a[2]]
                d = [v1, v2, [1, 1, 1]]
                return np.isclose(np.linalg.det(d), 0, atol=1e-5)

            for i, displ in enumerate(self.split_phdispl_cart):
                # Assumes one set of eigenvectors per q-point of the segment -- TODO confirm with get_dyn_mat_eigenvec.
                eigenvectors = get_dyn_mat_eigenvec(displ, self.structure, amu=self.amu)
                ind_block = np.zeros((len(displ), self.num_branches), dtype=int)
                # if it's not the first block, match the first two points with the second-to-last point of the
                # previous block (last_eigenvectors). Should give a match in case of LO-TO splitting
                if i == 0:
                    # First segment: identity ordering at the first point, second point matched to it.
                    ind_block[0] = range(self.num_branches)
                    match = match_eigenvectors(eigenvectors[0], eigenvectors[1])
                    # Presumably match[m] is the branch at the new point corresponding to branch m at the
                    # reference point, so this composes the orderings -- verify against match_eigenvectors.
                    ind_block[1] = [match[m] for m in ind_block[0]]
                else:
                    # Both points are mapped through the ordering of the previous segment's second-to-last point.
                    match = match_eigenvectors(last_eigenvectors, eigenvectors[0])
                    ind_block[0] = [match[m] for m in split_matched_indices[-1][-2]]
                    match = match_eigenvectors(last_eigenvectors, eigenvectors[1])
                    ind_block[1] = [match[m] for m in split_matched_indices[-1][-2]]
                for j in range(2, len(displ)):
                    # Match with the previous point, or skip it (use j-2) when the path changes direction at j-1.
                    k = j-1
                    if not collinear(self.split_qpoints[i][j-2], self.split_qpoints[i][j-1], self.split_qpoints[i][j]):
                        k = j-2
                    match = match_eigenvectors(eigenvectors[k], eigenvectors[j])
                    ind_block[j] = [match[m] for m in ind_block[k]]

                split_matched_indices.append(ind_block)
                last_eigenvectors = eigenvectors[-2]

            self._split_matched_indices = split_matched_indices

            return self._split_matched_indices
1358
1359    def _get_non_anal_freqs(self, frac_direction):
1360        # directions for the qph2l in anaddb are given in cartesian coordinates
1361        cart_direction = self.structure.lattice.reciprocal_lattice_crystallographic.get_cartesian_coords(frac_direction)
1362        cart_direction = cart_direction / np.linalg.norm(cart_direction)
1363
1364        for i, d in enumerate(self.non_anal_directions):
1365            d = d / np.linalg.norm(d)
1366            if np.allclose(cart_direction, d):
1367                return self.non_anal_phfreqs[i]
1368
1369        raise ValueError("Non analytical contribution has not been calculated for reduced direction {0} ".format(frac_direction))
1370
1371    def _get_non_anal_phdispl(self, frac_direction):
1372        # directions for the qph2l in anaddb are given in cartesian coordinates
1373        cart_direction = self.structure.lattice.reciprocal_lattice_crystallographic.get_cartesian_coords(frac_direction)
1374        cart_direction = cart_direction / np.linalg.norm(cart_direction)
1375
1376        for i, d in enumerate(self.non_anal_directions):
1377            d = d / np.linalg.norm(d)
1378            if np.allclose(cart_direction, d):
1379                return self.non_anal_phdispl_cart[i]
1380
1381        raise ValueError("Non analytical contribution has not been calcolated for reduced direction {0} ".format(frac_direction))
1382
1383    def _make_ticks_and_labels(self, qlabels):
1384        """Return ticks and labels from the mapping {qred: qstring} given in qlabels."""
1385        #TODO should be modified in order to handle the "split" list of qpoints
1386        if qlabels is not None:
1387            d = OrderedDict()
1388
1389            for qcoord, qname in qlabels.items():
1390                # Build Kpoint instancee
1391                qtick = Kpoint(qcoord, self.structure.reciprocal_lattice)
1392                for q, qpoint in enumerate(self.qpoints):
1393                    if qtick == qpoint:
1394                        d[q] = qname
1395        else:
1396            d = self._auto_qlabels
1397
1398        # Return ticks, labels
1399        return list(d.keys()), list(d.values())
1400
1401    # TODO: fatbands along x, y, z
1402    @add_fig_kwargs
1403    def plot_fatbands(self, use_eigvec=True, units="eV", colormap="jet", phdos_file=None,
1404                      alpha=0.6, max_stripe_width_mev=5.0, width_ratios=(2, 1),
1405                      qlabels=None, ylims=None, fontsize=12, **kwargs):
1406        r"""
1407        Plot phonon fatbands and, optionally, atom-projected phonon DOSes.
1408        The width of the band is given by ||v_{type}||
1409        where v is the (complex) phonon displacement (eigenvector) in cartesian coordinates and
1410        v_{type} selects only the terms associated to the atomic type.
1411
1412        Args:
1413            use_eigvec: True if the width of the phonon branch should be computed from the eigenvectors.
1414                False to use phonon displacements. Note that the PHDOS is always decomposed in
1415                terms of (orthonormal) eigenvectors.
1416            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1417                Case-insensitive.
1418            colormap: Have a look at the colormaps here and decide which one you like:
1419                http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html
1420            phdos_file: Used to activate fatbands + PJDOS plot.
1421                Accept string with path of PHDOS.nc file or |PhdosFile| object.
1422            alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
1423            max_stripe_width_mev: The maximum width of the stripe in meV. Will be rescaled according to ``units``.
1424            width_ratios: Ratio between the width of the fatbands plots and the DOS plots.
1425                Used if `phdos_file` is not None
1426            ylims: Set the data limits for the y-axis. Accept tuple e.g. `(left, right)`
1427                   or scalar e.g. `left`. If left (right) is None, default values are used
1428            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
1429                The values are the labels. e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
1430            fontsize: Legend and title fontsize.
1431
1432        Returns: |matplotlib-Figure|
1433        """
1434        lw = kwargs.pop("lw", 2)
1435        factor = abu.phfactor_ev2units(units)
1436        ntypat = self.structure.ntypesp
1437
1438        # Prepare PJDOS.
1439        close_phdos_file = False
1440        if phdos_file is not None:
1441            if is_string(phdos_file):
1442                phdos_file = PhdosFile(phdos_file)
1443                close_phdos_file = True
1444            else:
1445                if not isinstance(phdos_file, PhdosFile):
1446                    raise TypeError("Expecting string or PhdosFile, got %s" % type(phdos_file))
1447
1448        # Grid with [ntypat] plots if fatbands only or [ntypat, 2] if fatbands + PJDOS
1449        import matplotlib.pyplot as plt
1450        from matplotlib.gridspec import GridSpec
1451
1452        fig = plt.figure()
1453        nrows, ncols = (ntypat, 1) if phdos_file is None else (ntypat, 2)
1454        gspec = GridSpec(nrows=nrows, ncols=ncols, width_ratios=width_ratios if ncols == 2 else None,
1455                         wspace=0.05, hspace=0.1)
1456
1457        cmap = plt.get_cmap(colormap)
1458        qq = list(range(self.num_qpoints))
1459
1460        # phonon_displacements are in cartesian coordinates and stored in an array with shape
1461        # (nqpt, 3*natom, 3*natom) where the last dimension stores the cartesian components.
1462        # PJDoses are in cartesian coordinates and are computed by anaddb using the the
1463        # phonon eigenvectors that are orthonormal.
1464
1465        # Precompute normalization factor:
1466        # Here I use d2[q, nu] = \sum_{i=0}^{3*Nat-1) |d^{q\nu}_i|**2
1467        # it makes sense only for displacements
1468        d2_qnu = np.ones((self.num_qpoints, self.num_branches))
1469        if not use_eigvec:
1470            for iq in range(self.num_qpoints):
1471                for nu in self.branches:
1472                    cvect = self.phdispl_cart[iq, nu, :]
1473                    d2_qnu[iq, nu] = np.vdot(cvect, cvect).real
1474
1475        # Plot fatbands: one plot per atom type.
1476        ax00 = None
1477        for ax_row, symbol in enumerate(self.structure.symbol_set):
1478            last_ax = (ax_row == len(self.structure.symbol_set) - 1)
1479            ax = plt.subplot(gspec[ax_row, 0], sharex=ax00, sharey=ax00)
1480            if ax_row == 0: ax00 = ax
1481            self.decorate_ax(ax, units=units, qlabels=qlabels)
1482            color = cmap(float(ax_row) / max(1, ntypat - 1))
1483
1484            # dir_indices lists the coordinate indices for the atoms of the same type.
1485            atom_indices = self.structure.indices_from_symbol(symbol)
1486            dir_indices = []
1487            for aindx in atom_indices:
1488                start = 3 * aindx
1489                dir_indices.extend([start, start + 1, start + 2])
1490            dir_indices = np.array(dir_indices)
1491
1492            for nu in self.branches:
1493                yy_qq = self.phfreqs[:, nu] * factor
1494
1495                # Exctract the sub-vector associated to this atom type (eigvec or diplacement).
1496                if use_eigvec:
1497                    v_type = self.dyn_mat_eigenvect[:, nu, dir_indices]
1498                else:
1499                    v_type = self.phdispl_cart[:, nu, dir_indices]
1500
1501                v2_type = np.empty(self.num_qpoints)
1502                for iq in range(self.num_qpoints):
1503                    v2_type[iq] = np.vdot(v_type[iq], v_type[iq]).real
1504
1505                # Normalize and scale by max_stripe_width_mev taking into account units.
1506                # The stripe is centered on the phonon branch hence the factor 2
1507                stype_qq = (factor * max_stripe_width_mev * 1.e-3 / 2) * np.sqrt(v2_type / d2_qnu[:, nu])
1508
1509                # Plot the phonon branch with the stripe.
1510                if nu == 0:
1511                    ax.plot(qq, yy_qq, lw=lw, color=color, label=symbol)
1512                else:
1513                    ax.plot(qq, yy_qq, lw=lw, color=color)
1514
1515                ax.fill_between(qq, yy_qq + stype_qq, yy_qq - stype_qq, facecolor=color, alpha=alpha, linewidth=0)
1516
1517            set_axlims(ax, ylims, "y")
1518            ax.legend(loc="best", fontsize=fontsize, shadow=True)
1519
1520        # Type projected DOSes (always computed from eigenvectors in anaddb).
1521        if phdos_file is not None:
1522            ax01 = None
1523            for ax_row, symbol in enumerate(self.structure.symbol_set):
1524                color = cmap(float(ax_row) / max(1, ntypat - 1))
1525                ax = plt.subplot(gspec[ax_row, 1], sharex=ax01, sharey=ax00)
1526                if ax_row == 0: ax01 = ax
1527
1528                # Get PJDOS: Dictionary symbol --> partial PhononDos
1529                pjdos = phdos_file.pjdos_symbol[symbol]
1530                x, y = pjdos.mesh * factor, pjdos.values / factor
1531
1532                ax.plot(y, x, lw=lw, color=color)
1533                ax.grid(True)
1534                ax.yaxis.set_ticks_position("right")
1535                ax.yaxis.set_label_position("right")
1536                set_axlims(ax, ylims, "y")
1537
1538            if close_phdos_file:
1539                phdos_file.close()
1540
1541        return fig
1542
1543    @add_fig_kwargs
1544    def plot_with_phdos(self, phdos, units="eV", qlabels=None, ax_list=None, width_ratios=(2, 1), **kwargs):
1545        r"""
1546        Plot the phonon band structure with the phonon DOS.
1547
1548        Args:
1549            phdos: An instance of |PhononDos| or a netcdf file providing a PhononDos object.
1550            units: Units for plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1551                Case-insensitive.
1552            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
1553                The values are the labels e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
1554            ax_list: The axes for the bandstructure plot and the DOS plot. If ax_list is None, a new figure
1555                is created and the two axes are automatically generated.
1556            width_ratios: Ratio between the width of the bands plots and the DOS plots.
1557                Used if ``ax_list`` is None
1558
1559        Returns: |matplotlib-Figure|
1560        """
1561        phdos = PhononDos.as_phdos(phdos, phdos_kwargs=None)
1562
1563        import matplotlib.pyplot as plt
1564        if ax_list is None:
1565            # Build axes and align bands and DOS.
1566            from matplotlib.gridspec import GridSpec
1567            fig = plt.figure()
1568            gspec = GridSpec(1, 2, width_ratios=width_ratios, wspace=0.05)
1569            ax1 = plt.subplot(gspec[0])
1570            ax2 = plt.subplot(gspec[1], sharey=ax1)
1571        else:
1572            # Take them from ax_list.
1573            ax1, ax2 = ax_list
1574            fig = plt.gcf()
1575
1576        if not kwargs:
1577            kwargs = {"color": "black", "linewidth": 2.0}
1578
1579        # Plot the phonon band structure.
1580        self.plot_ax(ax1, branch=None, units=units, **kwargs)
1581        self.decorate_ax(ax1, units=units, qlabels=qlabels)
1582
1583        factor = abu.phfactor_ev2units(units)
1584        emin = np.min(self.minfreq)
1585        emin -= 0.05 * abs(emin)
1586        emin *= factor
1587        emax = np.max(self.maxfreq)
1588        emax += 0.05 * abs(emax)
1589        emax *= factor
1590        ax1.yaxis.set_view_interval(emin, emax)
1591
1592        # Plot Phonon DOS
1593        phdos.plot_dos_idos(ax2, what="d", units=units, exchange_xy=True, **kwargs)
1594
1595        ax2.grid(True)
1596        ax2.yaxis.set_ticks_position("right")
1597        #ax2.yaxis.set_label_position("right")
1598
1599        return fig
1600
1601    @add_fig_kwargs
1602    def plot_phdispl(self, qpoint, cart_dir=None, use_reduced_coords=False, ax=None, units="eV",
1603                     is_non_analytical_direction=False, use_eigvec=False,
1604                     colormap="viridis", hatches="default", atoms_index=None, labels_groups=None,
1605                     normalize=True, use_sqrt=False, fontsize=12, branches=None, format_w="%.3f", **kwargs):
1606        """
1607        Plot vertical bars with the contribution of the different atoms or atomic types to all the phonon modes
1608        at a given ``qpoint``. The contribution is given by ||v_{type}||
1609        where v is the (complex) phonon displacement (eigenvector) in cartesian coordinates and
1610        v_{type} selects only the terms associated to the atomic type.
1611        Options allow to specify which atoms should be taken into account and how should be reparted.
1612
1613        Args:
1614            qpoint: integer, vector of reduced coordinates or |Kpoint| object.
1615            cart_dir: "x", "y", or "z" to select a particular Cartesian directions. or combinations separated by "+".
1616                Example: "x+y". None if no projection is wanted.
1617            ax: |matplotlib-Axes| or None if a new figure should be created.
1618            units: Units for phonon frequencies. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1619                Case-insensitive.
1620            is_non_analytical_direction: If True, the ``qpoint`` is interpreted as a direction in q-space
1621                and the phonon (displacements/eigenvectors) for q --> 0 along this direction are used.
1622                Requires band structure with :class:`NonAnalyticalPh` object.
1623            use_eigvec: True if eigenvectors should be used instead of displacements (eigenvectors
1624                are orthonormal, unlike diplacements)
1625            colormap: Matplotlib colormap used for atom type.
1626            hatches: List of strings with matplotlib hatching patterns. None or empty list to disable hatching.
1627            fontsize: Legend and title fontsize.
1628            normalize: if True divides by the square norm of the total eigenvector
1629            use_sqrt: if True the square root of the sum of the components will be taken
1630            use_reduced_coords: if True coordinates will be converted to reduced coordinates. So the values will be
1631                fraction of a,b,c rather than x,y,z.
1632            atoms_index: list of lists. Each list contains the indices of atoms in the structure that will be
1633                summed on a separate group. if None all the atoms will be considered and grouped by type.
1634            labels_groups: If atoms_index is not None will provide the labels for each of the group in atoms_index.
1635                Should have the same length of atoms_index or be None. If None automatic labelling will be used.
1636            branches: list of indices for the modes that should be represented. If None all the modes will be shown.
1637            format_w: string used to format the values of the frequency. Default "%.3f".
1638
1639        Returns: |matplotlib-Figure|
1640        """
1641        factor = abu.phfactor_ev2units(units)
1642
1643        dxyz = {"x": 0, "y": 1, "z": 2, None: None}
1644
1645        if cart_dir is None:
1646            icart = None
1647        else:
1648            icart = [dxyz[c] for c in cart_dir.split("+")]
1649
1650        iq, qpoint = self.qindex_qpoint(qpoint, is_non_analytical_direction=is_non_analytical_direction)
1651
1652        if use_sqrt:
1653            f_sqrt = np.sqrt
1654        else:
1655            f_sqrt = lambda x: x
1656
1657        if branches is None:
1658            branches = self.branches
1659        elif not isinstance(branches, (list, tuple)):
1660            branches = [branches]
1661
1662        ax, fig, plt = get_ax_fig_plt(ax=ax)
1663        cmap = plt.get_cmap(colormap)
1664        ntypat = self.structure.ntypesp
1665
1666        if is_non_analytical_direction:
1667            ax.set_title("q-direction = %s" % repr(qpoint), fontsize=fontsize)
1668        else:
1669            ax.set_title("qpoint = %s" % repr(qpoint), fontsize=fontsize)
1670        ax.set_xlabel('Frequency %s' % abu.phunit_tag(units))
1671
1672        what = r"\epsilon" if use_eigvec else "d"
1673        if icart is None:
1674            ax.set_ylabel(r"${|\vec{%s}_{type}|} (stacked)$" % what, fontsize=fontsize)
1675        else:
1676            ax.set_ylabel(r"${|\vec{%s}_{%s,type}|} (stacked)$" % (what, cart_dir), fontsize=fontsize)
1677
1678        symbol2indices = self.structure.get_symbol2indices()
1679
1680        width, pad = 4, 1
1681        pad = width + pad
1682        xticks, xticklabels = [], []
1683        if hatches == "default":
1684            hatches = ["/", "\\", "'", "|", "-", "+", "x", "o", "O", ".", "*"]
1685        else:
1686            hatches = list_strings(hatches) if hatches is not None else []
1687
1688        x = 0
1689        for inu, nu in enumerate(branches):
1690            # Select frequencies and cartesian displacements/eigenvectors
1691            if is_non_analytical_direction:
1692                w_qnu = self.non_anal_phfreqs[iq, nu] * factor
1693                if use_eigvec:
1694                    vcart_qnu = np.reshape(self.non_anal_ph.dyn_mat_eigenvect[iq, nu], (len(self.structure), 3))
1695                else:
1696                    vcart_qnu = np.reshape(self.non_anal_phdispl_cart[iq, nu], (len(self.structure), 3))
1697            else:
1698                w_qnu = self.phfreqs[iq, nu] * factor
1699                if use_eigvec:
1700                    vcart_qnu = np.reshape(self.dyn_mat_eigenvect[iq, nu], (len(self.structure), 3))
1701                else:
1702                    vcart_qnu = np.reshape(self.phdispl_cart[iq, nu], (len(self.structure), 3))
1703
1704            if use_reduced_coords:
1705                vcart_qnu = np.dot(vcart_qnu, self.structure.lattice.inv_matrix)
1706
1707            if normalize:
1708                vnorm2 = f_sqrt(sum(np.linalg.norm(d) ** 2 for d in vcart_qnu))
1709            else:
1710                vnorm2 = 1.0
1711
1712            # Make a bar plot with rectangles bounded by (x - width/2, x + width/2, bottom, bottom + height)
1713            # The align keyword controls if x is interpreted as the center or the left edge of the rectangle.
1714            bottom, height = 0.0, 0.0
1715            if atoms_index is None:
1716                for itype, (symbol, inds) in enumerate(symbol2indices.items()):
1717                    if icart is None:
1718                        height = f_sqrt(sum(np.linalg.norm(d) ** 2 for d in vcart_qnu[inds]) / vnorm2)
1719                    else:
1720                        height = f_sqrt(
1721                            sum(np.linalg.norm(d) ** 2 for ic in icart for d in vcart_qnu[inds, ic]) / vnorm2)
1722
1723                    ax.bar(x, height, width, bottom, align="center",
1724                           color=cmap(float(itype) / max(1, ntypat - 1)),
1725                           label=symbol if inu == 0 else None, edgecolor='black',
1726                           hatch=hatches[itype % len(hatches)] if hatches else None,
1727                           )
1728                    bottom += height
1729            else:
1730                for igroup, inds in enumerate(atoms_index):
1731                    inds = np.array(inds)
1732
1733                    if labels_groups:
1734                        symbol = labels_groups[igroup]
1735                    else:
1736                        symbol = "+".join("{}{}".format(self.structure[ia].specie.name, ia) for ia in inds)
1737
1738                    if icart is None:
1739                        height = f_sqrt(sum(np.linalg.norm(d) ** 2 for d in vcart_qnu[inds]) / vnorm2)
1740                    else:
1741                        height = f_sqrt(
1742                            sum(np.linalg.norm(d) ** 2 for ic in icart for d in vcart_qnu[inds, ic]) / vnorm2)
1743
1744                    ax.bar(x, height, width, bottom, align="center",
1745                           color=cmap(float(igroup) / max(1, len(atoms_index) - 1)),
1746                           label=symbol if inu == 0 else None, edgecolor='black',
1747                           hatch=hatches[igroup % len(hatches)] if hatches else None,
1748                           )
1749                    bottom += height
1750
1751            xticks.append(x)
1752            xticklabels.append(format_w % w_qnu)
1753            x += (width + pad) / 2
1754
1755        ax.set_xticks(xticks)
1756        ax.set_xticklabels((xticklabels))
1757        ax.legend(loc="best", fontsize=fontsize, shadow=True)
1758
1759        return fig
1760
1761    @add_fig_kwargs
1762    def plot_phdispl_cartdirs(self, qpoint, cart_dirs=("x", "y", "z"), units="eV",
1763                              is_non_analytical_direction=False, use_eigvec=False,
1764                              colormap="viridis", hatches="default", atoms_index=None, labels_groups=None,
1765                              normalize=True, use_sqrt=False, fontsize=8, branches=None, format_w="%.3f", **kwargs):
1766        """
1767        Plot three panels. Each panel shows vertical bars with the contribution of the different atomic types
1768        to all the phonon displacements at the given ``qpoint`` along on the Cartesian directions in ``cart_dirs``.
1769
1770        Args:
1771            qpoint: integer, vector of reduced coordinates or |Kpoint| object.
1772            cart_dirs: List of strings defining the Cartesian directions. "x", "y", or "z" to select a particular
1773                Cartesian directions. or combinations separated by "+". Example: "x+y".
1774            units: Units for phonon frequencies. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1775                Case-insensitive.
1776            is_non_analytical_direction: If True, the ``qpoint`` is interpreted as a direction in q-space
1777                and the phonon (displacements/eigenvectors) for q --> 0 along this direction are used.
1778                Requires band structure with :class:`NonAnalyticalPh` object.
1779            use_eigvec: True if eigenvectors should be used instead of displacements (eigenvectors
1780                are orthonormal, unlike diplacements)
1781            colormap: Matplotlib colormap used for atom type.
1782            hatches: List of strings with matplotlib hatching patterns. None or empty list to disable hatching.
1783            fontsize: Legend and title fontsize.
1784            normalize: if True divides by the square norm of the total eigenvector
1785            use_sqrt: if True the square root of the sum of the components will be taken
1786                fraction of a,b,c rather than x,y,z.
1787            atoms_index: list of lists. Each list contains the indices of atoms in the structure that will be
1788                summed on a separate group. if None all the atoms will be considered and grouped by type.
1789            labels_groups: If atoms_index is not None will provide the labels for each of the group in atoms_index.
1790                Should have the same length of atoms_index or be None. If None automatic labelling will be used.
1791            branches: list of indices for the modes that should be represented. If None all the modes will be shown.
1792            format_w: string used to format the values of the frequency. Default "%.3f".
1793
1794        See plot_phdispl for the meaning of the other arguments.
1795        """
1796        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=len(cart_dirs), ncols=1,
1797                                                sharex=True, sharey=True, squeeze=False)
1798
1799        for i, (cart_dir, ax) in enumerate(zip(cart_dirs, ax_list.ravel())):
1800            self.plot_phdispl(qpoint, cart_dir=cart_dir, ax=ax, units=units, colormap=colormap,
1801                              is_non_analytical_direction=is_non_analytical_direction, use_eigvec=use_eigvec,
1802                              fontsize=fontsize, hatches=hatches, atoms_index=atoms_index, labels_groups=labels_groups,
1803                              normalize=normalize, use_sqrt=use_sqrt, branches=branches, show=False, format_w=format_w)
1804            # Disable artists.
1805            if i != 0:
1806                set_visible(ax, False, "legend", "title")
1807            #if len(cart_dirs) == 3 and i != 1:
1808            #    set_visible(ax, False, "ylabel")
1809            if i != len(cart_dirs) - 1:
1810                set_visible(ax, False, "xlabel")
1811
1812        return fig
1813
1814    def get_dataframe(self):
1815        """
1816        Return a |pandas-DataFrame| with the following columns:
1817
1818            ['qidx', 'mode', 'freq', 'qpoint']
1819
1820        where:
1821
1822        ==============  ==========================
1823        Column          Meaning
1824        ==============  ==========================
1825        qidx            q-point index.
1826        mode            phonon branch index.
1827        freq            Phonon frequency in eV.
1828        qpoint          |Kpoint| object
1829        ==============  ==========================
1830        """
1831        import pandas as pd
1832        rows = []
1833        for iq, qpoint in enumerate(self.qpoints):
1834            for nu in self.branches:
1835                rows.append(OrderedDict([
1836                           ("qidx", iq),
1837                           ("mode", nu),
1838                           ("freq", self.phfreqs[iq, nu]),
1839                           ("qpoint", self.qpoints[iq]),
1840                        ]))
1841
1842        return pd.DataFrame(rows, columns=list(rows[0].keys()))
1843
1844    @add_fig_kwargs
1845    def boxplot(self, ax=None, units="eV", mode_range=None, swarm=False, **kwargs):
1846        """
1847        Use seaborn_ to draw a box plot to show distributions of eigenvalues with respect to the mode index.
1848
1849        Args:
1850            ax: |matplotlib-Axes| or None if a new figure should be created.
1851            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
1852                Case-insensitive.
1853            mode_range: Only modes such as `mode_range[0] <= mode_index < mode_range[1]` are included in the plot.
1854            swarm: True to show the datapoints on top of the boxes
1855            kwargs: Keyword arguments passed to seaborn boxplot.
1856        """
1857        # Get the dataframe and select bands
1858        frame = self.get_dataframe()
1859        if mode_range is not None:
1860            frame = frame[(frame["mode"] >= mode_range[0]) & (frame["mode"] < mode_range[1])]
1861
1862        ax, fig, plt = get_ax_fig_plt(ax=ax)
1863        ax.grid(True)
1864
1865        factor = abu.phfactor_ev2units(units)
1866        yname = "freq %s" % abu.phunit_tag(units)
1867        frame[yname] = factor * frame["freq"]
1868
1869        import seaborn as sns
1870        hue = None
1871        ax = sns.boxplot(x="mode", y=yname, data=frame, hue=hue, ax=ax, **kwargs)
1872        if swarm:
1873            sns.swarmplot(x="mode", y=yname, data=frame, hue=hue, color=".25", ax=ax)
1874
1875        return fig
1876
1877    def to_pymatgen(self, qlabels=None):
1878        r"""
1879        Creates a pymatgen :class:`PhononBandStructureSymmLine` object.
1880
1881        Args:
1882            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
1883                The values are the labels e.g. ``qlabels = {(0.0,0.0,0.0):"$\Gamma$", (0.5,0,0):"L"}``.
1884                If None labels will be determined automatically.
1885        """
1886        # pymatgen labels dict is inverted
1887        if qlabels is None:
1888            qlabels = self._auto_qlabels
1889            # the indices in qlabels are without the split
1890            labels_dict = {v: self.qpoints[k].frac_coords for k, v in qlabels.items()}
1891        else:
1892            labels_dict = {v: k for k, v in qlabels.items()}
1893
1894        labelled_q_list = list(labels_dict.values())
1895
1896        ph_freqs, qpts, displ = [], [], []
1897        for split_q, split_phf, split_phdispl in zip(self.split_qpoints, self.split_phfreqs, self.split_phdispl_cart):
1898            # if the qpoint has a label it needs to be repeated. If it is one of the extrema either it should
1899            # not be repeated (if they are the real first or last point) or they will be already repeated due
1900            # to the split. Also they should not be repeated in case there are two consecutive labelled points.
1901            # So first determine which ones have a label.
1902            labelled = [any(np.allclose(q, labelled_q) for labelled_q in labelled_q_list) for q in split_q]
1903
1904            for i, (q, phf, d, l) in enumerate(zip(split_q, split_phf, split_phdispl, labelled)):
1905                ph_freqs.append(phf)
1906                qpts.append(q)
1907                d = d.reshape(self.num_branches, self.num_atoms, 3)
1908                displ.append(d)
1909
1910                if 0 < i < len(split_q) - 1 and l and not labelled[i-1] and not labelled[i+1]:
1911                    ph_freqs.append(phf)
1912                    qpts.append(q)
1913                    displ.append(d)
1914
1915        ph_freqs = np.transpose(ph_freqs) * abu.eV_to_THz
1916        qpts = np.array(qpts)
1917        displ = np.transpose(displ, (1, 0, 2, 3))
1918
1919        return PhononBandStructureSymmLine(qpoints=qpts, frequencies=ph_freqs,
1920                                           lattice=self.structure.reciprocal_lattice,
1921                                           has_nac=self.non_anal_ph is not None, eigendisplacements=displ,
1922                                           labels_dict=labels_dict, structure=self.structure)
1923
1924    @classmethod
1925    def from_pmg_bs(cls, pmg_bs, structure=None):
1926        """
1927        Creates an instance of the object from a :class:`PhononBandStructureSymmLine` object.
1928
1929        Args:
1930            pmg_bs: the instance of PhononBandStructureSymmLine.
1931            structure: a |Structure| object. Should be present if the structure attribute is
1932                not set in pmg_bs.
1933        """
1934
1935        structure = structure or pmg_bs.structure
1936        if not structure:
1937            raise ValueError("The structure is needed to create the abipy object.")
1938
1939        structure = Structure.from_sites(structure)
1940        structure.spgset_abi_spacegroup(has_timerev=False)
1941
1942        qpoints = []
1943        phfreqs = []
1944        phdispl_cart = []
1945        names = []
1946
1947        prev_q = None
1948        for b in pmg_bs.branches:
1949            qname1, qname2 = b["name"].split("-")
1950            start_index = b["start_index"]
1951            if prev_q is not None and qname1 == prev_q:
1952                start_index += 1
1953
1954            # it can happen depending on how the object was generated
1955            if start_index >= b["end_index"]:
1956                continue
1957
1958            prev_q = qname2
1959
1960            if start_index == b["start_index"]:
1961                names.append(qname1)
1962
1963            names.extend([None] * (b["end_index"] - b["start_index"] - 1))
1964            names.append(qname2)
1965
1966            for i in range(start_index, b["end_index"] + 1):
1967                qpoints.append(pmg_bs.qpoints[i].frac_coords)
1968            phfreqs.extend(pmg_bs.bands.T[start_index:b["end_index"] + 1])
1969            if pmg_bs.has_eigendisplacements:
1970                e = pmg_bs.eigendisplacements[:, start_index:b["end_index"] + 1]
1971                e = np.transpose(e, [0, 1, 2, 3])
1972                e = np.reshape(e, e.shape[:-2] + (-1,))
1973                phdispl_cart.extend(e)
1974
1975        #print(len(names), len(phfreqs))
1976        qpoints_list = KpointList(reciprocal_lattice=structure.reciprocal_lattice,
1977                                  frac_coords=qpoints, names=names)
1978
1979        phfreqs = np.array(phfreqs) / abu.eV_to_THz
1980        n_modes = 3 * len(structure)
1981        if not phdispl_cart:
1982            phdispl_cart = np.zeros((len(phfreqs), n_modes, n_modes))
1983        else:
1984            phdispl_cart = np.array(phdispl_cart)
1985
1986        na = None
1987        if pmg_bs.has_nac:
1988            directions = []
1989            nac_phreqs = []
1990            nac_phdispl = []
1991
1992            for t in pmg_bs.nac_frequencies:
1993                # directions in NonAnalyticalPh are given in cartesian coordinates
1994                cart_direction = structure.lattice.reciprocal_lattice_crystallographic.get_cartesian_coords(t[0])
1995                cart_direction = cart_direction / np.linalg.norm(cart_direction)
1996
1997                directions.append(cart_direction)
1998                nac_phreqs.append(t[1])
1999
2000            nac_phreqs = np.array(nac_phreqs) / abu.eV_to_THz
2001
2002            for t in pmg_bs.nac_eigendisplacements:
2003                nac_phdispl.append(t[1].reshape(n_modes, n_modes))
2004
2005            na = NonAnalyticalPh(structure=structure, directions=np.array(directions),
2006                                 phfreqs=nac_phreqs, phdispl_cart=np.array(nac_phdispl))
2007
2008        phb = cls(structure=structure, qpoints=qpoints_list, phfreqs=phfreqs, phdispl_cart=phdispl_cart,
2009                  non_anal_ph=na)
2010
2011        return phb
2012
2013    def acoustic_indices(self, qpoint, threshold=0.95, raise_on_no_indices=True):
2014        """
2015        Extract the indices of the three acoustic modes for a qpoint.
2016        Acoustic modes could be reasonably identified for Gamma and points close to Gamma.
2017
2018        Args:
2019            qpoint: the qpoint. Accepts integer or reduced coordinates
2020            threshold: fractional value allowed for the matching of the displacements to identify acoustic modes.
2021            raise_on_no_indices: if True a RuntimeError will be raised if the acoustic mode will not be
2022                correctly identified. If False [0, 1, 2] will be returned.
2023        """
2024        qindex = self.qindex(qpoint)
2025        phdispl = self.phdispl_cart[qindex]
2026
2027        indices = []
2028        for mode, displ_mode in enumerate(phdispl):
2029            displ_mode = np.reshape(displ_mode, (-1, 3))
2030            a = displ_mode[0] / np.linalg.norm(displ_mode[0])
2031            for d in displ_mode[1:]:
2032                b = d / np.linalg.norm(d)
2033                if np.dot(a, b) < threshold:
2034                    break
2035            else:
2036                indices.append(mode)
2037
2038        if len(indices) != 3:
2039            if raise_on_no_indices:
2040                raise RuntimeError('wrong number of indices: {}'.format(indices))
2041            else:
2042                indices = [0, 1, 2]
2043
2044        return indices
2045
2046    def asr_breaking(self, units='eV', threshold=0.95, raise_on_no_indices=True):
2047        """
2048        Calculates the breaking of the acoustic sum rule.
2049        Requires the presence of Gamma.
2050
2051        Args:
2052            units: Units for the output. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
2053                Case-insensitive.
2054            threshold: fractional value allowed for the matching of the displacements to identify acoustic modes.
2055            raise_on_no_indices: if True a RuntimeError will be raised if the acoustic mode will not be
2056                correctly identified
2057
2058        Returns:
2059            A namedtuple with:
2060                the three breaking of the acoustic modes
2061                the maximum breaking with sign
2062                the absolute value of the maximum breaking
2063        """
2064        gamma_ind = self.qpoints.index((0, 0, 0))
2065        ind = self.acoustic_indices(gamma_ind, threshold=threshold, raise_on_no_indices=raise_on_no_indices)
2066        asr_break = self.phfreqs[0, ind] * abu.phfactor_ev2units(units)
2067
2068        imax = np.argmax(asr_break)
2069
2070        return dict2namedtuple(breakings=asr_break, max_break=asr_break[imax], absmax_break=abs(asr_break[imax]))
2071
2072    def get_frozen_phonons(self, qpoint, nmode, eta=1, scale_matrix=None, max_supercell=None):
2073        """
2074        Creates a supercell with displaced atoms for the specified q-point and mode.
2075
2076        Args:
2077            qpoint: q vector in reduced coordinate in reciprocal space or index of the qpoint.
2078            nmode: number of the mode.
2079            eta: pre-factor multiplying the displacement. Gives the value in Angstrom of the
2080                largest displacement.
2081            scale_matrix: the scaling matrix of the supercell. If None a scaling matrix suitable for
2082                the qpoint will be determined.
2083            max_supercell: mandatory if scale_matrix is None, ignored otherwise. Defines the largest
2084                supercell in the search for a scaling matrix suitable for the q point.
2085
2086        Returns:
2087            A namedtuple with a Structure with the displaced atoms, a numpy array containing the
2088            displacements applied to each atom and the scale matrix used to generate the supercell.
2089        """
2090        qind = self.qindex(qpoint)
2091        displ = self.phdispl_cart[qind, nmode].reshape((-1, 3))
2092
2093        return self.structure.frozen_phonon(qpoint=self.qpoints[qind].frac_coords, displ=displ, eta=eta,
2094                                            frac_coords=False, scale_matrix=scale_matrix, max_supercell=max_supercell)
2095
2096    def get_longitudinal_fraction(self, qpoint, idir=None):
2097        """
2098        Calculates "longitudinal" fraction of the eigendisplacements.
2099
2100        Args:
2101            qpoint: q vector in reduced coordinate in reciprocal space or index of the qpoint.
2102            idir: an integer with the index of the non analytical direction if qpoint is gamma.
2103                If None all will be given.
2104
2105        Returns:
2106            A numpy array with the longitudinal fractions for each mode of the specified q point.
2107            If qpoint is gamma and idir is None it will be a numpy array with all the non analytical
2108            directions.
2109        """
2110        qind = self.qindex(qpoint)
2111        qpoint = self.qpoints[qind]
2112
2113        def get_fraction(direction, displ):
2114            displ = np.real(displ)
2115            # Normalization. Such that \sum_i dot(q, displ[i]) <= 1
2116            # and = 1 if q is parallel to displ[i] for each i.
2117            displ_norm = np.sum(np.linalg.norm(displ, axis=-1), axis=-1)
2118            displ = displ / displ_norm[:, None, None]
2119            versor = direction / np.linalg.norm(direction)
2120            return np.absolute(np.dot(displ, versor)).sum(axis=-1)
2121
2122        if qpoint.is_gamma():
2123            if self.non_anal_phdispl_cart is None:
2124                raise RuntimeError("Cannot calculate the lo/to fraction at Gamma if the non analytical"
2125                                   "contributions have not been calculated.")
2126            phdispl = self.non_anal_phdispl_cart.reshape((len(self.non_anal_directions), self.num_branches, self.num_atoms, 3))
2127            if idir is None:
2128                fractions = []
2129                for non_anal_dir, phd in zip(self.non_anal_directions, phdispl):
2130                    fractions.append(get_fraction(non_anal_dir, phd))
2131                return np.array(fractions)
2132            else:
2133                return get_fraction(self.non_anal_directions[idir], phdispl[idir])
2134        else:
2135            phdispl = self.phdispl_cart[qind].reshape((self.num_branches, self.num_atoms, 3))
2136            return get_fraction(qpoint.cart_coords, phdispl)
2137
2138    @add_fig_kwargs
2139    def plot_longitudinal_fraction(self, qpoint, idir=None, ax_list=None, units="eV", branches=None,
2140                                   format_w="%.3f", fontsize=10, **kwargs):
2141        """
2142        Plots an histogram "longitudinal" fraction of the eigendisplacements.
2143
2144        Args:
2145            qpoint: q vector in reduced coordinate in reciprocal space or index of the qpoint.
2146            idir: an integer with the index of the non analytical direction if qpoint is gamma.
2147                If None all will be plot.
2148            ax_list: The axes for the plot. If ax_list is None, a new figure is created and
2149                the axes are automatically generated.
2150            units: Units for the output. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
2151                Case-insensitive.
2152            branches: list of indices for the modes that should be represented. If None all the modes will be shown.
2153            format_w: string used to format the values of the frequency. Default "%.3f".
2154            fontsize: Labels and title fontsize.
2155
2156        Returns:
2157            |matplotlib-Figure|
2158
2159        """
2160        qind = self.qindex(qpoint)
2161        qpoint = self.qpoints[qind]
2162        fractions = self.get_longitudinal_fraction(qind, idir)
2163
2164        factor = abu.phfactor_ev2units(units)
2165
2166        if branches is None:
2167            branches = self.branches
2168        elif not isinstance(branches, (list, tuple)):
2169            branches = [branches]
2170
2171        is_non_anal = qpoint.is_gamma()
2172
2173        # if non analytical directions at gamma the
2174        if len(fractions.shape) == 1:
2175            fractions = [fractions]
2176
2177        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=len(fractions), ncols=1,
2178                                                sharex=False, sharey=False, squeeze=False)
2179
2180        width, pad = 4, 1
2181        pad = width + pad
2182
2183        for i, ax in enumerate(ax_list.ravel()):
2184            xticks, xticklabels = [], []
2185            x = 0
2186            if idir is not None:
2187                i_ref = idir
2188            else:
2189                i_ref = i
2190            for inu, nu in enumerate(branches):
2191                height = fractions[i][nu]
2192                ax.bar(x, height, width, 0, align="center",
2193                       color="r", edgecolor='black')
2194
2195                xticks.append(x)
2196                if is_non_anal:
2197                    w_qnu = self.non_anal_phfreqs[i_ref, nu] * factor
2198                else:
2199                    w_qnu = self.phfreqs[qind, nu] * factor
2200                xticklabels.append(format_w % w_qnu)
2201
2202                x += (width + pad) / 2
2203
2204            if is_non_anal:
2205                # no title for multiple axes, not enough space.
2206                if idir is not None:
2207                    ax.set_title(f"q-direction = {self.non_anal_directions[i_ref]}", fontsize=fontsize)
2208            else:
2209                ax.set_title(f"qpoint = {repr(qpoint)}", fontsize=fontsize)
2210
2211            ax.set_ylabel(r"Longitudinal fraction", fontsize=fontsize)
2212            ax.set_ylim(0, 1)
2213
2214            ax.set_xticks(xticks)
2215            ax.set_xticklabels((xticklabels))
2216
2217            if i == len(fractions) - 1:
2218                ax.set_xlabel(f'Frequency {abu.phunit_tag(units)}')
2219
2220        return fig
2221
2222    @add_fig_kwargs
2223    def plot_longitudinal_fatbands(self, ax=None, units="eV", qlabels=None, branch_range=None, match_bands=False,
2224                                   sum_degenerate=False, factor=1, **kwargs):
2225        r"""
2226        Plot the phonon band structure with width representing the longitudinal fraction of the fatbands.
2227
2228        Args:
2229            ax: |matplotlib-Axes| or None if a new figure should be created.
2230            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
2231                Case-insensitive.
2232            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
2233                The values are the labels. e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
2234            branch_range: Tuple specifying the minimum and maximum branch index to plot (default: all branches are plotted).
2235            match_bands: if True the bands will be matched based on the scalar product between the eigenvectors.
2236            sum_degenerate: if True modes with similar frequencies will be considered as degenerated and their
2237                contributions will be summed (squared sum). Notice that this may end up summing contributions
2238                from modes that are just accidentally degenerated.
2239            factor: a float that will used to scale the width of the fatbands.
2240
2241        Returns:
2242            |matplotlib-Figure|
2243        """
2244        # Select the band range.
2245        if branch_range is None:
2246            branch_range = range(self.num_branches)
2247        else:
2248            branch_range = range(branch_range[0], branch_range[1], 1)
2249
2250        ax, fig, plt = get_ax_fig_plt(ax=ax)
2251
2252        # Decorate the axis (e.g add ticks and labels).
2253        self.decorate_ax(ax, units=units, qlabels=qlabels)
2254
2255        if "color" not in kwargs: kwargs["color"] = "black"
2256        if "linewidth" not in kwargs: kwargs["linewidth"] = 1.0
2257
2258        first_xx = 0
2259
2260        units_factor = abu.phfactor_ev2units(units)
2261
2262        for i, (q_l, pf_l) in enumerate(zip(self.split_qpoints, self.split_phfreqs)):
2263            if match_bands:
2264                ind = self.split_matched_indices[i]
2265                pf_l = pf_l[np.arange(len(pf_l))[:, None], ind]
2266            pf_l = pf_l * units_factor
2267            xx = list(range(first_xx, first_xx + len(pf_l)))
2268            for branch in branch_range:
2269                ax.plot(xx, pf_l[:, branch], **kwargs)
2270            first_xx = xx[-1]
2271
2272            width = []
2273            for iq, (q, pf) in enumerate(zip(q_l, pf_l)):
2274
2275                print(q)
2276                if np.allclose(np.mod(q, 1), [0, 0, 0]):
2277                    if self.non_anal_ph is not None:
2278                        if iq == 0:
2279                            direction = q_l[iq+1]
2280                        else:
2281                            direction = q_l[iq-1]
2282                        idir = self.non_anal_ph.index_direction(direction)
2283                        frac = self.get_longitudinal_fraction(q, idir)
2284                    else:
2285                        frac = np.zeros(self.num_branches)
2286                else:
2287                    frac = self.get_longitudinal_fraction(q)
2288
2289                # sum the contributions from degenerate modes
2290                if sum_degenerate:
2291                    pf_round = pf.round(decimals=int(6 * units_factor))
2292                    partitioned_pf = [np.where(pf_round == element)[0].tolist() for element in np.unique(pf_round)]
2293                    for group in partitioned_pf:
2294                        if len(group) > 1:
2295                            frac[group[0]] = np.linalg.norm(frac[group])
2296                            frac[group[1:]] = 0
2297
2298                if match_bands:
2299                    ind = self.split_matched_indices[i]
2300                    frac = frac[ind[iq]]
2301
2302                width.append(frac * units_factor * factor / 600)
2303
2304            width = np.array(width)
2305            for branch in branch_range:
2306                ax.fill_between(xx, pf_l[:, branch] + width[:, branch], pf_l[:, branch] - width[:, branch],
2307                                facecolor="r", alpha=0.4, linewidth=0)
2308
2309        return fig
2310
2311    @add_fig_kwargs
2312    def plot_qpt_distance(self, qpt_list=None, ngqpt=None, shiftq=(0, 0, 0), plot_distances=False,
2313                          units="eV", qlabels=None, branch_range=None, colormap="viridis_r",
2314                          match_bands=False, log_scale=False, **kwargs):
2315        r"""
2316        Plot the phonon band structure coloring the point according to the minimum distance of
2317        the qpoints of the path from a list of qpoints. This can be for example defined as the
2318        q-points effectively calculated in DFPT.
2319        Optionally plot the explicit values.
2320
2321        Args:
2322            qpt_list: list of fractional coordinates or KpointList of the qpoints from which the minimum
2323                distance will be calculated.
2324            ngqpt: the division of a regular grid of qpoints. Used to automatically fill in the qpt_list
2325                based on abipy.core.kpoints.kmesh_from_mpdivs.
2326            shiftq: the shifts of a regular grid of qpoints. Used to automatically fill in the qpt_list
2327                based on abipy.core.kpoints.kmesh_from_mpdivs.
2328            plot_distances: if True a second plot will be added with the explicit values of the distances.
2329            ax: |matplotlib-Axes| or None if a new figure should be created.
2330            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
2331                Case-insensitive.
2332            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
2333                The values are the labels. e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
2334            branch_range: Tuple specifying the minimum and maximum branch_i index to plot
2335                (default: all branches are plotted).
2336            colormap: matplotlib colormap to determine the colors available.
2337                http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html
2338            match_bands: if True the bands will be matched based on the scalar product between the eigenvectors.
2339            log_scale: if True the values will be plotted in a log scale.
2340
2341        Returns: |matplotlib-Figure|
2342        """
2343        from matplotlib.collections import LineCollection
2344
2345        if qpt_list is None:
2346            if ngqpt is None:
2347                raise ValueError("at least one among qpt_list and ngqpt should be provided")
2348            qpt_list = kmesh_from_mpdivs(ngqpt, shiftq, pbc=False, order="bz")
2349
2350        if isinstance(qpt_list, KpointList):
2351            qpt_list = qpt_list.frac_coords
2352
2353        # Select the band range.
2354        if branch_range is None:
2355            branch_range = range(self.num_branches)
2356        else:
2357            branch_range = range(branch_range[0], branch_range[1], 1)
2358
2359        nrows = 2 if plot_distances else 1
2360        ncols = 1
2361        ax_list, fig, plt = get_axarray_fig_plt(ax_array=None, nrows=nrows, ncols=ncols,
2362                                                sharex=True, sharey=False, squeeze=True)
2363
2364        # make a list in case of only one plot
2365        if not plot_distances:
2366            ax_list = [ax_list]
2367
2368        # Decorate the axis (e.g add ticks and labels).
2369        self.decorate_ax(ax_list[-1], units=units, qlabels=qlabels)
2370
2371        first_xx = 0
2372        factor = abu.phfactor_ev2units(units)
2373
2374        linewidth = 2
2375        if "lw" in kwargs:
2376            linewidth = kwargs.pop("lw")
2377        elif "linewidth" in kwargs:
2378            linewidth = kwargs.pop("linewidth")
2379
2380        rec_latt = self.structure.reciprocal_lattice
2381
2382        # calculate all the value to set the color normalization
2383        split_min_dist = []
2384        for i, q_l in enumerate(self.split_qpoints):
2385            all_dist = rec_latt.get_all_distances(q_l, qpt_list)
2386            split_min_dist.append(np.min(all_dist, axis=-1))
2387
2388        if log_scale:
2389            import matplotlib
2390            # find the minimum value larger than zero and set the 0 to that value
2391            min_value = np.min([v for l in split_min_dist for v in l if v > 0])
2392            for min_list in split_min_dist:
2393                min_list[min_list == 0] = min_value
2394            norm = matplotlib.colors.LogNorm(min_value, np.max(split_min_dist), clip=True)
2395        else:
2396            norm = plt.Normalize(np.min(split_min_dist), np.max(split_min_dist))
2397
2398        segments = []
2399        total_min_dist = []
2400
2401        for i, (pf, min_dist) in enumerate(zip(self.split_phfreqs, split_min_dist)):
2402            if match_bands:
2403                ind = self.split_matched_indices[i]
2404                pf = pf[np.arange(len(pf))[:, None], ind]
2405            pf = pf * factor
2406            xx = range(first_xx, first_xx + len(pf))
2407
2408            for branch_i in branch_range:
2409                points = np.array([xx, pf[:, branch_i]]).T.reshape(-1, 1, 2)
2410                segments.append(np.concatenate([points[:-1], points[1:]], axis=1))
2411                total_min_dist.extend(min_dist[:-1])
2412
2413            first_xx = xx[-1]
2414
2415        segments = np.concatenate(segments)
2416        total_min_dist = np.array(total_min_dist)
2417
2418        lc = LineCollection(segments, cmap=colormap, norm=norm)
2419        lc.set_array(total_min_dist)
2420        lc.set_linewidth(linewidth)
2421
2422        line = ax_list[-1].add_collection(lc)
2423
2424        # line collection does not autoscale the plot
2425        ax_list[-1].set_ylim(np.min(self.split_phfreqs), np.max(self.split_phfreqs))
2426
2427        fig.colorbar(line, ax=ax_list)
2428
2429        if plot_distances:
2430            first_xx = 0
2431            for i, (q_l, min_dist) in enumerate(zip(self.split_qpoints, split_min_dist)):
2432                xx = list(range(first_xx, first_xx + len(q_l)))
2433                ax_list[0].plot(xx, min_dist, linewidth=linewidth, **kwargs)
2434
2435                first_xx = xx[-1]
2436            ax_list[0].grid(True)
2437        return fig
2438
2439
class PHBST_Reader(ETSF_Reader):
    """
    Reader for the PHBST.nc netcdf file produced by anaddb.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: PHBST_Reader
    """

    def read_qredcoords(self):
        """Return the array with the reduced coordinates of the q-points."""
        qred = self.read_value("qpoints")
        return qred

    def read_qweights(self):
        """Return the weights of the q-points."""
        weights = self.read_value("qweights")
        return weights

    def read_phfreqs(self):
        """Return the |numpy-array| with the phonon frequencies in eV."""
        freqs = self.read_value("phfreqs")
        return freqs

    def read_phdispl_cart(self):
        """
        Return the complex array with the Cartesian displacements in **Angstrom**
        shape is [num_qpoints,  mu_mode,  cart_direction].
        """
        displ = self.read_value("phdispl_cart", cmode="c")
        return displ

    def read_amu(self):
        """Return the atomic mass units or None if the variable is not stored in the file."""
        amu = self.read_value("atomic_mass_units", default=None)
        return amu

    def read_epsinf_zcart(self):
        """
        Return the tuple (epsinf, zcart) with the electronic dielectric tensor and the
        Born effective charges in Cartesian coordinates. Missing entries are returned as None.
        """
        # netcdf declarations (Fortran order, hence the transpositions below):
        # nctkarr_t('emacro_cart', "dp", 'number_of_cartesian_directions, number_of_cartesian_directions')
        # nctkarr_t('becs_cart', "dp", "number_of_cartesian_directions, number_of_cartesian_directions, number_of_atoms")]
        raw_eps = self.read_value("emacro_cart", default=None)
        epsinf = None if raw_eps is None else raw_eps.T.copy()

        raw_z = self.read_value("becs_cart", default=None)
        zcart = None if raw_z is None else raw_z.transpose(0, 2, 1).copy()

        return epsinf, zcart
2483
2484
class PhbstFile(AbinitNcFile, Has_Structure, Has_PhononBands, NotebookWriter):
    """
    Object used to access data stored in the PHBST.nc file produced by ABINIT.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: PhbstFile
    """

    def __init__(self, filepath):
        """
        Args:
            filepath: path to the PHBST.nc file.
        """
        super().__init__(filepath)
        self.reader = PHBST_Reader(filepath)

        # Initialize Phonon bands and add metadata from ncfile.
        # If this step fails, close the reader so that the file handle is not leaked.
        try:
            self._phbands = PhononBands.from_file(filepath)
        except Exception:
            self.reader.close()
            raise

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """
        String representation

        Args:
            verbose: verbosity level.
        """
        lines = []; app = lines.append

        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")

        app(self.phbands.to_string(title=None, with_structure=True, with_qpoints=False, verbose=verbose))

        return "\n".join(lines)

    @property
    def structure(self):
        """|Structure| object"""
        return self.phbands.structure

    @property
    def qpoints(self):
        """List of q-point objects."""
        return self.phbands.qpoints

    @property
    def phbands(self):
        """|PhononBands| object"""
        return self._phbands

    def close(self):
        """Close the file."""
        self.reader.close()

    @lazy_property
    def params(self):
        """:class:`OrderedDict` with parameters that might be subject to convergence studies."""
        od = self.get_phbands_params()
        return od

    def qindex(self, qpoint):
        """
        Returns the index of the qpoint in the PhbstFile.
        Accepts integer, vector with reduced coordinates or |Kpoint|.
        """
        return self.phbands.qindex(qpoint)

    def qindex_qpoint(self, qpoint, is_non_analytical_direction=False):
        """
        Returns (qindex, qpoint).
        Accepts integer, vector with reduced coordinates or |Kpoint|.
        """
        return self.phbands.qindex_qpoint(qpoint, is_non_analytical_direction=is_non_analytical_direction)

    def get_phframe(self, qpoint, with_structure=True):
        """
        Return a |pandas-DataFrame| with the phonon frequencies at the given q-point and
        information on the crystal structure (used for convergence studies).
        The q-point is stored in the ``qpoint`` attribute of the frame.

        Args:
            qpoint: integer, vector of reduced coordinates or |Kpoint| object.
            with_structure: True to add structural parameters.
        """
        qindex, qpoint = self.qindex_qpoint(qpoint)
        phfreqs = self.phbands.phfreqs

        d = dict(
            omega=phfreqs[qindex, :],
            branch=list(range(3 * len(self.structure))),
        )

        # Add geometrical information
        if with_structure:
            d.update(self.structure.get_dict4pandas(with_spglib=True))

        # Build the pandas Frame and add the q-point as attribute.
        import pandas as pd
        frame = pd.DataFrame(d, columns=list(d.keys()))
        frame.qpoint = qpoint

        return frame

    def get_phmode(self, qpoint, branch):
        """
        Returns the :class:`PhononMode` with the given qpoint and branch nu.

        Args:
            qpoint: Either a vector with the reduced components of the q-point
                or an integer giving the sequential index (C-convention).
            branch: branch index (C-convention)

        Returns:
            :class:`PhononMode` instance.
        """
        qindex, qpoint = self.qindex_qpoint(qpoint)

        return PhononMode(qpoint=qpoint,
                          freq=self.phbands.phfreqs[qindex, branch],
                          displ_cart=self.phbands.phdispl_cart[qindex, branch, :],
                          structure=self.structure)

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        return self.yield_phbands_figs(**kwargs)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to nbpath. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(ncfile)"),
            nbv.new_code_cell("ncfile.phbands.plot();"),
            nbv.new_code_cell("ncfile.phbands.qpoints.plot();"),
            #nbv.new_code_cell("ncfile.phbands.get_phdos().plot();"),
        ])

        return self._write_nb_nbpath(nb, nbpath)
2632
2633
_THERMO_YLABELS = {  # [name][units] --> latex string
    "internal_energy": {"eV": "$U(T)$ (eV/cell)", "Jmol": "$U(T)$ (J/mole)"},
    "free_energy": {"eV": "$F(T) + ZPE$ (eV/cell)", "Jmol": "$F(T) + ZPE$ (J/mole)"},
    # Entropy and heat capacity are energies per unit temperature, hence the 1/K.
    "entropy": {"eV": "$S(T)$ (eV/K/cell)", "Jmol": "$S(T)$ (J/K/mole)"},
    "cv": {"eV": "$C_V(T)$ (eV/K/cell)", "Jmol": "$C_V(T)$ (J/K/mole)"},
}
2640
2641
2642class PhononDos(Function1D):
2643    """
2644    This object stores the phonon density of states.
2645    An instance of ``PhononDos`` has a ``mesh`` (numpy array with the points of the mesh)
2646    and another numpy array, ``values``, with the DOS on the mesh.
2647
2648    .. note::
2649
2650        mesh is given in eV, values are in states/eV.
2651    """
2652
2653    @classmethod
2654    def as_phdos(cls, obj, phdos_kwargs=None):
2655        """
2656        Return an instance of |PhononDos| from a generic obj. Supports::
2657
2658            - instances of cls
2659            - files (string) that can be open with abiopen and that provide one of the following attributes: [`phdos`, `phbands`]
2660            - instances of |PhononBands|.
2661            - objects providing a ``phbands`` attribute.
2662
2663        Args:
2664            phdos_kwargs: optional dictionary with the options passed to ``get_phdos`` to compute the phonon DOS.
2665            Used when obj is not already an instance of `cls` or when we have to compute the DOS from obj.
2666        """
2667        if phdos_kwargs is None: phdos_kwargs = {}
2668
2669        if isinstance(obj, cls):
2670            return obj
2671
2672        elif is_string(obj):
2673            # path? (pickle or file supported by abiopen)
2674            if obj.endswith(".pickle"):
2675                with open(obj, "rb") as fh:
2676                    return cls.as_phdos(pickle.load(fh), phdos_kwargs)
2677
2678            from abipy.abilab import abiopen
2679            with abiopen(obj) as abifile:
2680                if hasattr(abifile, "phdos"):
2681                    return abifile.phdos
2682                elif hasattr(abifile, "phbands"):
2683                    return abifile.phbands.get_phdos(**phdos_kwargs)
2684                else:
2685                    raise TypeError("Don't know how to create `PhononDos` from type: %s" % type(abifile))
2686
2687        elif isinstance(obj, PhononBands):
2688            return obj.get_phdos(**phdos_kwargs)
2689
2690        elif hasattr(obj, "phbands"):
2691            return obj.phbands.get_phdos(**phdos_kwargs)
2692
2693        elif hasattr(obj, "phdos"):
2694            return obj.phdos
2695
2696        raise TypeError("Don't know how to create PhononDos object from type: `%s`" % type(obj))
2697
2698    @lazy_property
2699    def iw0(self):
2700        """
2701        Index of the first point in the mesh whose value is >= 0
2702        """
2703        iw0 = self.find_mesh_index(0.0)
2704        if iw0 == -1:
2705            raise ValueError("Cannot find zero in energy mesh")
2706        return iw0
2707
2708    @lazy_property
2709    def idos(self):
2710        """Integrated DOS."""
2711        return self.integral()
2712
2713    @lazy_property
2714    def zero_point_energy(self):
2715        """Zero point energy in eV per unit cell."""
2716        iw0 = self.iw0
2717        return Energy(0.5 * np.trapz(self.mesh[iw0:] * self.values[iw0:], x=self.mesh[iw0:]), "eV")
2718
2719    def plot_dos_idos(self, ax, what="d", exchange_xy=False, units="eV", **kwargs):
2720        """
2721        Helper function to plot DOS/IDOS on the axis ``ax``.
2722
2723        Args:
2724            ax: |matplotlib-Axes|
2725            what: string selecting the quantity to plot:
2726                "d" for DOS, "i" for IDOS. chars can be concatenated
2727                hence what="id" plots both IDOS and DOS. (default "d").
2728            exchange_xy: True to exchange axis
2729            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
2730                Case-insensitive.
2731            kwargs: Options passed to matplotlib plot method.
2732
2733        Return:
2734            list of lines added to the plot.
2735        """
2736        opts = [c.lower() for c in what]
2737        lines = []
2738
2739        for c in opts:
2740            f = {"d": self, "i": self.idos}[c]
2741            xfactor = abu.phfactor_ev2units(units)
2742            # Don't rescale IDOS
2743            yfactor = 1 / xfactor if c == "d" else 1
2744
2745            ls = f.plot_ax(ax, exchange_xy=exchange_xy, xfactor=xfactor, yfactor=yfactor, **kwargs)
2746            lines.extend(ls)
2747
2748        return lines
2749
2750    # TODO: This should be called plot_dos_idos!
2751    @add_fig_kwargs
2752    def plot(self, units="eV", **kwargs):
2753        """
2754        Plot Phonon DOS and IDOS on two distict plots.
2755
2756        Args:
2757            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
2758                Case-insensitive.
2759            kwargs: Keyword arguments passed to :mod:`matplotlib`.
2760
2761        Returns: |matplotlib-Figure|
2762        """
2763        import matplotlib.pyplot as plt
2764        from matplotlib.gridspec import GridSpec
2765
2766        fig = plt.figure()
2767        gspec = GridSpec(2, 1, height_ratios=[1, 2], wspace=0.05)
2768        ax1 = plt.subplot(gspec[0])
2769        ax2 = plt.subplot(gspec[1])
2770
2771        for ax in (ax1, ax2):
2772            ax.grid(True)
2773
2774        ax2.set_xlabel('Energy %s' % abu.phunit_tag(units))
2775        ax1.set_ylabel("IDOS (states)")
2776        ax2.set_ylabel("DOS %s" % abu.phdos_label_from_units(units))
2777
2778        self.plot_dos_idos(ax1, what="i", units=units, **kwargs)
2779        self.plot_dos_idos(ax2, what="d", units=units, **kwargs)
2780
2781        return fig
2782
2783    def get_internal_energy(self, tstart=5, tstop=300, num=50):
2784        """
2785        Returns the internal energy, in eV, in the harmonic approximation for different temperatures
2786        Zero point energy is included.
2787
2788        Args:
2789            tstart: The starting value (in Kelvin) of the temperature mesh.
2790            tstop: The end value (in Kelvin) of the mesh.
2791            num (int): optional Number of samples to generate. Default is 50.
2792
2793        Return: |Function1D| object with U(T) + ZPE.
2794        """
2795        tmesh = np.linspace(tstart, tstop, num=num)
2796        w, gw = self.mesh[self.iw0:], self.values[self.iw0:]
2797        if w[0] < 1e-12:
2798            w, gw = self.mesh[self.iw0+1:], self.values[self.iw0+1:]
2799        coth = lambda x: 1.0 / np.tanh(x)
2800
2801        vals = np.empty(len(tmesh))
2802        for it, temp in enumerate(tmesh):
2803            if temp == 0:
2804                vals[it] = self.zero_point_energy
2805            else:
2806                wd2kt = w / (2 * abu.kb_eVK * temp)
2807                vals[it] = 0.5 * np.trapz(w * coth(wd2kt) * gw, x=w)
2808            #print(vals[it])
2809
2810        return Function1D(tmesh, vals)
2811
2812    def get_entropy(self, tstart=5, tstop=300, num=50):
2813        """
2814        Returns the entropy, in eV/K, in the harmonic approximation for different temperatures
2815
2816        Args:
2817            tstart: The starting value (in Kelvin) of the temperature mesh.
2818            tstop: The end value (in Kelvin) of the mesh.
2819            num (int): optional Number of samples to generate. Default is 50.
2820
2821        Return: |Function1D| object with S(T).
2822        """
2823        tmesh = np.linspace(tstart, tstop, num=num)
2824        w, gw = self.mesh[self.iw0:], self.values[self.iw0:]
2825        if w[0] < 1e-12:
2826            w, gw = self.mesh[self.iw0+1:], self.values[self.iw0+1:]
2827        coth = lambda x: 1.0 / np.tanh(x)
2828
2829        vals = np.empty(len(tmesh))
2830        for it, temp in enumerate(tmesh):
2831            if temp == 0:
2832                vals[it] = 0
2833            else:
2834                wd2kt = w / (2 * abu.kb_eVK * temp)
2835                vals[it] = np.trapz((wd2kt * coth(wd2kt) - np.log(2 * np.sinh(wd2kt))) * gw, x=w)
2836
2837        return Function1D(tmesh, abu.kb_eVK * vals)
2838
2839    def get_free_energy(self, tstart=5, tstop=300, num=50):
2840        """
2841        Returns the free energy, in eV, in the harmonic approximation for different temperatures
2842        Zero point energy is included.
2843
2844        Args:
2845            tstart: The starting value (in Kelvin) of the temperature mesh.
2846            tstop: The end value (in Kelvin) of the mesh.
2847            num (int): optional Number of samples to generate. Default is 50.
2848
2849        Return: |Function1D| object with F(T) = U(T) + ZPE - T x S(T)
2850        """
2851        uz = self.get_internal_energy(tstart=tstart, tstop=tstop, num=num)
2852        s = self.get_entropy(tstart=tstart, tstop=tstop, num=num)
2853
2854        return Function1D(uz.mesh, uz.values - s.mesh * s.values)
2855
2856    def get_cv(self, tstart=5, tstop=300, num=50):
2857        """
2858        Returns the constant-volume specific heat, in eV/K, in the harmonic approximation
2859        for different temperatures
2860
2861        Args:
2862            tstart: The starting value (in Kelvin) of the temperature mesh.
2863            tstop: The end value (in Kelvin) of the mesh.
2864            num (int): optional Number of samples to generate. Default is 50.
2865
2866        Return: |Function1D| object with C_v(T).
2867        """
2868        tmesh = np.linspace(tstart, tstop, num=num)
2869        w, gw = self.mesh[self.iw0:], self.values[self.iw0:]
2870        if w[0] < 1e-12:
2871            w, gw = self.mesh[self.iw0+1:], self.values[self.iw0+1:]
2872        csch2 = lambda x: 1.0 / (np.sinh(x) ** 2)
2873
2874        vals = np.empty(len(tmesh))
2875        for it, temp in enumerate(tmesh):
2876            if temp == 0:
2877                vals[it] = 0
2878            else:
2879                wd2kt = w / (2 * abu.kb_eVK * temp)
2880                vals[it] = np.trapz(wd2kt ** 2 * csch2(wd2kt) * gw, x=w)
2881
2882        return Function1D(tmesh, abu.kb_eVK * vals)
2883
2884    @add_fig_kwargs
2885    def plot_harmonic_thermo(self, tstart=5, tstop=300, num=50, units="eV", formula_units=None,
2886                             quantities=None, fontsize=8, **kwargs):
2887        """
2888        Plot thermodynamic properties from the phonon DOSes within the harmonic approximation.
2889
2890        Args:
2891            tstart: The starting value (in Kelvin) of the temperature mesh.
2892            tstop: The end value (in Kelvin) of the mesh.
2893            num: int, optional Number of samples to generate. Default is 50.
2894            quantities: List of strings specifying the thermodynamic quantities to plot.
2895                Possible values: ["internal_energy", "free_energy", "entropy", "c_v"].
2896                None means all.
2897            units: eV for energies in ev/unit_cell, Jmol for results in J/mole.
2898            formula_units: the number of formula units per unit cell. If unspecified, the
2899                thermodynamic quantities will be given on a per-unit-cell basis.
2900            fontsize: Legend and title fontsize.
2901
2902        Returns: |matplotlib-Figure|
2903        """
2904        quantities = list_strings(quantities) if quantities is not None else \
2905            ["internal_energy", "free_energy", "entropy", "cv"]
2906
2907        # Build grid of plots.
2908        ncols, nrows = 1, 1
2909        num_plots = len(quantities)
2910        if num_plots > 1:
2911            ncols = 2
2912            nrows = num_plots // ncols + num_plots % ncols
2913
2914        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
2915                                               sharex=True, sharey=False, squeeze=False)
2916        # don't show the last ax if num_plots is odd.
2917        if num_plots % ncols != 0: ax_mat[-1, -1].axis("off")
2918
2919        for iax, (qname, ax) in enumerate(zip(quantities, ax_mat.flat)):
2920            irow, icol = divmod(iax, ncols)
2921            # Compute thermodynamic quantity associated to qname.
2922            f1d = getattr(self, "get_" + qname)(tstart=tstart, tstop=tstop, num=num)
2923            ys = f1d.values
2924            if formula_units is not None: ys /= formula_units
2925            if units == "Jmol": ys = ys * abu.e_Cb * abu.Avogadro
2926            ax.plot(f1d.mesh, ys)
2927
2928            ax.set_title(qname, fontsize=fontsize)
2929            ax.grid(True)
2930            ax.set_xlabel("Temperature (K)", fontsize=fontsize)
2931            ax.set_ylabel(_THERMO_YLABELS[qname][units], fontsize=fontsize)
2932            #ax.legend(loc="best", fontsize=fontsize, shadow=True)
2933
2934            if irow != nrows:
2935                set_visible(ax, False, "xlabel")
2936
2937        return fig
2938
2939    def to_pymatgen(self):
2940        """
2941        Creates a pymatgen :class:`PmgPhononDos` object
2942        """
2943        factor = abu.phfactor_ev2units("thz")
2944
2945        return PmgPhononDos(self.mesh * factor, self.values / factor)
2946
2947    @property
2948    def debye_temp(self):
2949        """
2950        Debye temperature in K.
2951        """
2952        integrals = (self * self.mesh ** 2).spline_integral() / self.spline_integral()
2953        t_d = np.sqrt(5/3*integrals)/abu.kb_eVK
2954
2955        return t_d
2956
2957    def get_acoustic_debye_temp(self, nsites):
2958        """
2959        Acoustic Debye temperature in K, i.e. the Debye temperature divided by nsites**(1/3).
2960
2961        Args:
2962            nsites: the number of sites in the cell.
2963        """
2964        return self.debye_temp/nsites**(1/3)
2965
2966
class PhdosReader(ETSF_Reader):
    """
    This object reads data from the PHDOS.nc file produced by anaddb.

    .. note::

            Frequencies are in eV, DOSes are in states/eV.
    """
    @lazy_property
    def structure(self):
        """|Structure| object."""
        return self.read_structure()

    @lazy_property
    def wmesh(self):
        """The frequency mesh for the PH-DOS in eV."""
        return self.read_value("wmesh")

    def read_pjdos_type(self):
        """[ntypat, nomega] array with Phonon DOS projected over atom types."""
        return self.read_value("pjdos_type")

    def read_pjdos_atdir(self):
        """
        Return [natom, three, nomega] array with Phonon DOS projected over atoms and cartesian directions.
        """
        return self.read_value("pjdos")

    def read_phdos(self):
        """Return |PhononDos| object with the total phonon DOS"""
        return PhononDos(self.wmesh, self.read_value("phdos"))

    def read_pjdos_symbol_xyz_dict(self):
        """
        Return :class:`OrderedDict` mapping element symbol --> [3, nomega] array
        with the phonon DOSes summed over atom-types and decomposed along
        the three cartesian directions.
        """
        # The name is a bit confusing: rc stands for "real-space cartesian"
        # phdos_rc_type[ntypat, 3, nomega]
        values = self.read_value("pjdos_rc_type")

        od = OrderedDict()
        for symbol in self.chemical_symbols:
            type_idx = self.typeidx_from_symbol(symbol)
            od[symbol] = values[type_idx]

        return od

    def read_pjdos_symbol_dict(self):
        """
        Ordered dictionary mapping element symbol --> |PhononDos|
        where PhononDos is the contribution to the total DOS summed over atoms
        with chemical symbol ``symbol``.
        """
        # [ntypat, nomega] array with PH-DOS projected over atom types.
        values = self.read_pjdos_type()

        od = OrderedDict()
        for symbol in self.chemical_symbols:
            type_idx = self.typeidx_from_symbol(symbol)
            od[symbol] = PhononDos(self.wmesh, values[type_idx])

        return od

    def read_msq_dos(self):
        """
        Read generalized DOS with MSQ displacement tensor in cartesian coords.

        Return: |MsqDos| object.

        Raises:
            RuntimeError: if the file does not contain the ``msqd_dos_atom`` variable
                (files produced by Abinit versions older than 9).
        """
        if "msqd_dos_atom" not in self.rootgrp.variables:
            raise RuntimeError("PHDOS file does not contain `msqd_dos_atom` variable.\n" +
                               "Please use a more recent Abinit version >= 9")

        # nctkarr_t('msqd_dos_atom', "dp", 'number_of_frequencies, three, three, number_of_atoms') &
        # The Fortran declaration above is column-major, so the Python array presumably has
        # shape [natom, 3, 3, nomega] -- TODO confirm. transpose([0, 2, 1, 3]) swaps the two
        # cartesian axes: the tensor is symmetric but we still transpose (3,3) to be consistent.
        values = self.read_value("msqd_dos_atom").transpose([0, 2, 1, 3]).copy()

        # Read atomic masses and build dictionary element_symbol --> amu
        amu_symbol = self.read_amu_symbol()

        from abipy.dfpt.msqdos import MsqDos
        return MsqDos(self.structure, self.wmesh, values, amu_symbol)
3051
3052
class PhdosFile(AbinitNcFile, Has_Structure, NotebookWriter):
    """
    Container object storing the different DOSes stored in the
    PHDOS.nc file produced by anaddb.
    Provides helper function to visualize/extract data.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: PhdosFile
    """

    def __init__(self, filepath):
        # Open the file, read data and create objects.
        super().__init__(filepath)

        self.reader = r = PhdosReader(filepath)
        # Frequency mesh in eV shared by the total and projected DOSes read from the file.
        self.wmesh = r.wmesh

    def close(self):
        """Close the file."""
        self.reader.close()

    @lazy_property
    def params(self):
        """
        :class:`OrderedDict` with the convergence parameters
        Used to construct |pandas-DataFrames|.
        """
        # No convergence parameter is exported for PHDOS files at present.
        return {}

    def __str__(self):
        """Invoked by str"""
        return self.to_string()

    def to_string(self, verbose=0):
        """
        Human-readable string with useful information such as structure...

        Args:
            verbose: Verbosity level.
        """
        lines = []; app = lines.append

        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.structure.to_string(verbose=verbose, title="Structure"))
        app("")

        return "\n".join(lines)

    @lazy_property
    def structure(self):
        """|Structure| object."""
        return self.reader.structure

    @lazy_property
    def phdos(self):
        """|PhononDos| object with the total phonon DOS."""
        return self.reader.read_phdos()

    @lazy_property
    def pjdos_symbol(self):
        """
        Ordered dictionary mapping element symbol --> `PhononDos`
        where PhononDos is the contribution to the total DOS summed over atoms
        with chemical symbol `symbol`.
        """
        return self.reader.read_pjdos_symbol_dict()

    @lazy_property
    def msqd_dos(self):
        """
        |MsqDos| object with Mean square displacement tensor in cartesian coords.
        Allows one to calculate Debye Waller factors by integration with 1/omega and the Bose-Einstein factor.
        """
        return self.reader.read_msq_dos()

    @add_fig_kwargs
    def plot_pjdos_type(self, units="eV", stacked=True, colormap="jet", alpha=0.7, exchange_xy=False,
                        ax=None, xlims=None, ylims=None, fontsize=12, **kwargs):
        """
        Plot type-projected phonon DOS.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            stacked: True if DOS partial contributions should be stacked on top of each other.
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            colormap: Have a look at the colormaps
                `here <http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html>`_
                and decide which one you'd like:
            alpha: The alpha blending value, between 0 (transparent) and 1 (opaque).
            exchange_xy: True to exchange x-y axis.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used.
            ylims: y-axis limits.
            fontsize: legend and title fontsize.

        Returns: |matplotlib-Figure|
        """
        lw = kwargs.pop("lw", 2)
        factor = abu.phfactor_ev2units(units)

        ax, fig, plt = get_ax_fig_plt(ax=ax)
        cmap = plt.get_cmap(colormap)

        ax.grid(True)
        set_axlims(ax, xlims, "x")
        set_axlims(ax, ylims, "y")
        xlabel, ylabel = 'Frequency %s' % abu.phunit_tag(units), 'PJDOS %s' % abu.phdos_label_from_units(units)
        set_ax_xylabels(ax, xlabel, ylabel, exchange_xy)

        # Type projected DOSes.
        num_plots = len(self.pjdos_symbol)
        cumulative = np.zeros(len(self.wmesh))

        for i, (symbol, pjdos) in enumerate(self.pjdos_symbol.items()):
            # Frequencies are rescaled by factor, the DOS (per energy unit) by 1/factor.
            x, y = pjdos.mesh * factor, pjdos.values / factor
            if exchange_xy: x, y = y, x
            if num_plots != 1:
                color = cmap(float(i) / (num_plots - 1))
            else:
                color = cmap(0.0)

            if not stacked:
                ax.plot(x, y, lw=lw, label=symbol, color=color)
            else:
                if not exchange_xy:
                    ax.plot(x, cumulative + y, lw=lw, label=symbol, color=color)
                    ax.fill_between(x, cumulative, cumulative + y, facecolor=color, alpha=alpha)
                    cumulative += y
                else:
                    # After the swap, x holds the DOS values and y the frequency mesh.
                    ax.plot(cumulative + x, y, lw=lw, label=symbol, color=color)
                    ax.fill_betweenx(y, cumulative, cumulative + x, facecolor=color, alpha=alpha)
                    cumulative += x

        # Total PHDOS
        x, y = self.phdos.mesh * factor, self.phdos.values / factor
        if exchange_xy: x, y = y, x
        ax.plot(x, y, lw=lw, label="Total PHDOS", color='black')
        ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    @add_fig_kwargs
    def plot_pjdos_cartdirs_type(self, units="eV", stacked=True, colormap="jet", alpha=0.7,
                                 xlims=None, ylims=None, ax_list=None, fontsize=8, **kwargs):
        """
        Plot type-projected phonon DOS decomposed along the three cartesian directions.
        Three rows for each cartesian direction. Each row shows the contribution of each atomic type + Total Phonon DOS.

        Args:
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            stacked: True if DOS partial contributions should be stacked on top of each other.
            colormap: Have a look at the colormaps
                `here <http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html>`_
                and decide which one you'd like:
            alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            ylims: y-axis limits.
            ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
            fontsize: Legend and label fontsize.

        Returns: |matplotlib-Figure|.
        """
        lw = kwargs.pop("lw", 2)
        ntypat = self.structure.ntypesp
        factor = abu.phfactor_ev2units(units)

        # Three rows for each direction.
        # Each row shows the contribution of each atomic type + Total PH DOS.
        nrows, ncols = 3, 1
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=True, squeeze=True)
        ax_list = np.reshape(ax_list, (nrows, ncols)).ravel()
        cmap = plt.get_cmap(colormap)

        # symbol --> [three, number_of_frequencies] in cart dirs
        pjdos_symbol_rc = self.reader.read_pjdos_symbol_xyz_dict()

        xx = self.phdos.mesh * factor
        for idir, ax in enumerate(ax_list):
            ax.grid(True)
            set_axlims(ax, xlims, "x")
            set_axlims(ax, ylims, "y")

            ax.set_ylabel(r'PJDOS along %s' % {0: "x", 1: "y", 2: "z"}[idir])
            if idir == 2:
                ax.set_xlabel('Frequency %s' % abu.phunit_tag(units))

            # Plot Type projected DOSes along cartesian direction idir
            cumulative = np.zeros(len(self.wmesh))
            for itype, symbol in enumerate(self.reader.chemical_symbols):
                color = cmap(float(itype) / max(1, ntypat - 1))
                yy = pjdos_symbol_rc[symbol][idir] / factor

                if not stacked:
                    ax.plot(xx, yy, lw=lw, label=symbol, color=color)
                else:
                    ax.plot(xx, cumulative + yy, lw=lw, label=symbol, color=color)
                    ax.fill_between(xx, cumulative, cumulative + yy, facecolor=color, alpha=alpha)
                    cumulative += yy

            # Add Total PHDOS
            ax.plot(xx, self.phdos.values / factor, lw=lw, label="Total PHDOS", color='black')
            ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    @add_fig_kwargs
    def plot_pjdos_cartdirs_site(self, view="inequivalent", units="eV", stacked=True, colormap="jet", alpha=0.7,
                                 xlims=None, ylims=None, ax_list=None, fontsize=8, verbose=0, **kwargs):
        """
        Plot phonon PJDOS for each atom in the unit cell. By default, only "inequivalent" atoms are shown.

        Args:
            view: "inequivalent" to show only inequivalent atoms. "all" for all sites.
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            stacked: True if DOS partial contributions should be stacked on top of each other.
            colormap: matplotlib colormap.
            alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used.
            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
            fontsize: Legend and title fontsize.
            verbose: Verbosity level.

        Returns: |matplotlib-Figure|
        """
        factor = abu.phfactor_ev2units(units)
        lw = kwargs.pop("lw", 2)

        # Select atoms.
        aview = self._get_atomview(view, verbose=verbose)
        num_sites = len(aview.iatom_list)

        # Three rows for each cartesian direction.
        # Each row shows the contribution of each site + Total PH DOS.
        nrows, ncols = 3, 1
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=True, squeeze=True)
        ax_list = np.reshape(ax_list, (nrows, ncols)).ravel()
        cmap = plt.get_cmap(colormap)

        # [natom, three, nomega] array with PH-DOS projected over atoms and cartesian directions
        pjdos_atdir = self.reader.read_pjdos_atdir()

        xx = self.phdos.mesh * factor
        for idir, ax in enumerate(ax_list):
            ax.grid(True)
            set_axlims(ax, xlims, "x")
            set_axlims(ax, ylims, "y")

            ax.set_ylabel(r'PJDOS along %s' % {0: "x", 1: "y", 2: "z"}[idir])
            if idir == 2:
                ax.set_xlabel('Frequency %s' % abu.phunit_tag(units))

            # Plot site-projected DOSes along cartesian direction idir
            cumulative = np.zeros(len(self.wmesh))
            for ipos, iatom in enumerate(aview.iatom_list):
                site = self.structure[iatom]
                symbol = str(site)
                # Normalize by the position in the selected list (not by the absolute atom index)
                # so that the colormap argument stays in [0, 1] and every site gets its own color.
                color = cmap(float(ipos) / max(num_sites - 1, 1))
                yy = pjdos_atdir[iatom, idir] / factor

                if not stacked:
                    ax.plot(xx, yy, lw=lw, label=symbol, color=color)
                else:
                    ax.plot(xx, cumulative + yy, lw=lw, label=symbol, color=color)
                    ax.fill_between(xx, cumulative, cumulative + yy, facecolor=color, alpha=alpha)
                    cumulative += yy

            # Add Total PHDOS
            ax.plot(xx, self.phdos.values / factor, lw=lw, label="Total PHDOS", color='black')
            ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Used in abiview.py to get a quick look at the results.
        """
        units = kwargs.get("units", "mev")
        yield self.phdos.plot(units=units, show=False)
        yield self.plot_pjdos_type(units=units, show=False)
        # Old formats do not have MSQDOS arrays: this part is deliberately best-effort.
        try:
            msqd_dos = self.msqd_dos
        except Exception:
            msqd_dos = None
        if msqd_dos is not None:
            yield msqd_dos.plot(units=units, show=False)
            yield msqd_dos.plot_tensor(show=False)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to nbpath. If ``nbpath`` is None, a temporay file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(ncfile)"),
            nbv.new_code_cell("ncfile.phdos.plot();"),
            nbv.new_code_cell("ncfile.plot_pjdos_type();"),
            nbv.new_code_cell("ncfile.plot_pjdos_cartdirs_type(units='meV', stacked=True);"),
            nbv.new_code_cell("ncfile.plot_pjdos_cartdirs_site(view='inequivalent', units='meV', stacked=True);"),
            # TODO: add cells for msqd_dos.plot and msqd_dos.plot_tensor.
        ])

        return self._write_nb_nbpath(nb, nbpath)

    def to_pymatgen(self):
        """
        Creates a pymatgen :class:`PmgCompletePhononDos` object.

        Site-projected DOSes are summed over the three cartesian directions and rescaled
        with the same eV --> THz factor used for the total DOS.
        """
        total_dos = self.phdos.to_pymatgen()

        # [natom, three, nomega] array with PH-DOS projected over atoms and cartesian directions
        pjdos_atdir = self.reader.read_pjdos_atdir()

        factor = abu.phfactor_ev2units("thz")
        summed_pjdos = np.sum(pjdos_atdir, axis=1) / factor

        pdoss = {site: pdos for site, pdos in zip(self.structure, summed_pjdos)}

        return PmgCompletePhononDos(self.structure, total_dos, pdoss)
3395
3396
# FIXME: Remove. Use PhononBandsPlotter API.
@add_fig_kwargs
def phbands_gridplot(phb_objects, titles=None, phdos_objects=None, phdos_kwargs=None,
                     units="eV", width_ratios=(2, 1), fontsize=8, **kwargs):
    """
    Plot multiple phonon bandstructures and optionally DOSes on a grid.

    Args:
        phb_objects: List of objects from which the phonon band structures are extracted.
            Each item in phb_objects is either a string with the path of the netcdf file,
            or one of the abipy object with an ``phbands`` attribute or a |PhononBands| object.
        phdos_objects: List of objects from which the phonon DOSes are extracted.
            Accept filepaths or |PhononDos| objects. If phdos_objects is not None,
            each subplot in the grid contains a band structure with DOS else a simple bandstructure plot.
        titles: List of strings with the titles to be added to the subplots.
        phdos_kwargs: optional dictionary with the options passed to ``get_phdos`` to compute the phonon DOS.
            Used only if ``phdos_objects`` is not None.
        units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
            Case-insensitive.
        width_ratios: Ratio between the width of the phonon band plots and the DOS plots.
            Used if `phdos_objects` is not None
        fontsize: legend and title fontsize.

    Returns: |matplotlib-Figure|

    Raises:
        ValueError: if the number of DOS objects differs from the number of band structures.
    """
    # Build list of PhononBands objects.
    phbands_list = [PhononBands.as_phbands(obj) for obj in phb_objects]

    # Build list of PhononDos objects.
    phdos_list = []
    if phdos_objects is not None:
        if phdos_kwargs is None: phdos_kwargs = {}
        phdos_list = [PhononDos.as_phdos(obj, phdos_kwargs) for obj in phdos_objects]
        if len(phdos_list) != len(phbands_list):
            raise ValueError("The number of objects for DOS must be equal to the number of bands")

    import matplotlib.pyplot as plt
    nrows, ncols = 1, 1
    numeb = len(phbands_list)
    if numeb > 1:
        ncols = 2
        nrows = numeb // ncols + numeb % ncols

    if not phdos_list:
        # Plot grid with phonon bands only.
        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()
        # don't show the last ax if numeb is odd.
        if numeb % ncols != 0: ax_list[-1].axis("off")

        for i, (phbands, ax) in enumerate(zip(phbands_list, ax_list)):
            phbands.plot(ax=ax, units=units, show=False)
            if titles is not None: ax.set_title(titles[i], fontsize=fontsize)
            # y-axis is shared: keep the ylabel only on the first column.
            if i % ncols != 0:
                ax.set_ylabel("")

    else:
        # Plot grid with phonon bands + DOS
        # see http://matplotlib.org/users/gridspec.html
        from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
        fig = plt.figure()
        gspec = GridSpec(nrows, ncols)

        for i, (phbands, phdos) in enumerate(zip(phbands_list, phdos_list)):
            subgrid = GridSpecFromSubplotSpec(1, 2, subplot_spec=gspec[i], width_ratios=width_ratios, wspace=0.05)
            # Get axes and align bands and DOS. Axes are attached explicitly to fig
            # instead of relying on pyplot's "current figure".
            ax1 = fig.add_subplot(subgrid[0])
            ax2 = fig.add_subplot(subgrid[1], sharey=ax1)
            phbands.plot_with_phdos(phdos, ax_list=(ax1, ax2), units=units, show=False)

            if titles is not None: ax1.set_title(titles[i], fontsize=fontsize)
            if i % ncols != 0:
                for ax in (ax1, ax2):
                    ax.set_ylabel("")

    return fig
3473
3474
def dataframe_from_phbands(phbands_objects, index=None, with_spglib=True):
    """
    Build a pandas dataframe with the most important results of a list of phonon band structures.

    Args:
        phbands_objects: List of objects that can be converted to phonon bands objects..
            Support netcdf filenames or |PhononBands| objects
            See ``PhononBands.as_phbands`` for the complete list.
        index: Index of the dataframe.
        with_spglib: If True, spglib is invoked to get the spacegroup symbol and number.

    Return: |pandas-DataFrame|
    """
    converted = [PhononBands.as_phbands(obj) for obj in phbands_objects]
    # One dict (row) per band structure; the key order of the first row fixes the column order.
    rows = [phb.get_dict4pandas(with_spglib=with_spglib) for phb in converted]

    import pandas as pd
    columns = list(rows[0].keys()) if rows else None
    return pd.DataFrame(rows, index=index, columns=columns)
3495
3496
class PhononBandsPlotter(NotebookWriter):
    """
    Class for plotting phonon band structure and DOSes.
    Supports plots on the same graph or separated plots.

    Usage example:

    .. code-block:: python

        plotter = PhononBandsPlotter()
        plotter.add_phbands("foo bands", "foo_PHBST.nc")
        plotter.add_phbands("bar bands", "bar_PHBST.nc")
        plotter.gridplot()
    """
    # Used in iter_lineopt to generate matplotlib linestyles.
    _LINE_COLORS = ["b", "r", "g", "m", "y", "k"]
    _LINE_STYLES = ["-", ":", "--", "-.",]
    _LINE_WIDTHS = [2, ]

    def __init__(self, key_phbands=None, key_phdos=None, phdos_kwargs=None):
        """
        Args:
            key_phbands: List of (label, phbands) tuples.
                phbands is any object that can be converted into |PhononBands| e.g. ncfile, path.
            key_phdos: List of (label, phdos) tuples.
                phdos is any object that can be converted into |PhononDos|.
            phdos_kwargs: optional dictionary passed to ``PhononDos.as_phdos`` when converting
                the objects in ``key_phdos``.

        Raises:
            ValueError: if ``key_phdos`` is given without ``key_phbands`` or if the two lists
                have a different number of elements.
        """
        if key_phbands is None: key_phbands = []
        key_phbands = [(k, PhononBands.as_phbands(v)) for k, v in key_phbands]
        self._bands_dict = OrderedDict(key_phbands)

        if key_phdos is None: key_phdos = []
        key_phdos = [(k, PhononDos.as_phdos(v, phdos_kwargs)) for k, v in key_phdos]
        self._phdoses_dict = OrderedDict(key_phdos)
        if key_phdos:
            if not key_phbands:
                raise ValueError("key_phbands must be specified when key_phdos is not None")
            if len(key_phbands) != len(key_phdos):
                raise ValueError("key_phbands and key_phdos must have the same number of elements.")

    def __repr__(self):
        """Invoked by repr"""
        return self.to_string(func=repr)

    def __str__(self):
        """Invoked by str"""
        return self.to_string(func=str)

    def add_plotter(self, other):
        """
        Merge two plotters and return a new plotter. Neither ``self`` nor ``other`` is modified.

        Raises:
            TypeError: if ``other`` is not an instance of the same class.
        """
        if not isinstance(other, self.__class__):
            raise TypeError("Don't know how to add %s to %s" % (other.__class__, self.__class__))

        key_phbands = list(self._bands_dict.items()) + list(other._bands_dict.items())
        key_phdos = list(self._phdoses_dict.items()) + list(other._phdoses_dict.items())

        return self.__class__(key_phbands=key_phbands, key_phdos=key_phdos)

    def to_string(self, func=str, verbose=0):
        """
        String representation.

        Args:
            func: Callable used to stringify each stored object (e.g. ``str`` or ``repr``).
            verbose: Verbosity level (currently unused).
        """
        lines = []
        app = lines.append
        for i, (label, phbands) in enumerate(self.phbands_dict.items()):
            app("[%d] %s --> %s" % (i, label, func(phbands)))

        if self.phdoses_dict:
            for i, (label, phdos) in enumerate(self.phdoses_dict.items()):
                app("[%d] %s --> %s" % (i, label, func(phdos)))

        return "\n".join(lines)

    def has_same_formula(self):
        """
        True if the plotter contains structures with the same chemical formula.
        Returns True also when the plotter is empty.
        """
        structures = [phbands.structure for phbands in self.phbands_dict.values()]
        if structures and any(s.formula != structures[0].formula for s in structures): return False
        return True

    def get_phbands_frame(self, with_spglib=True):
        """
        Build a |pandas-DataFrame| with the most important results available in the band structures.
        The labels of the plotter are used as index.
        """
        return dataframe_from_phbands(list(self.phbands_dict.values()),
                                      index=list(self.phbands_dict.keys()), with_spglib=with_spglib)

    @property
    def phbands_dict(self):
        """Dictionary with the mapping label --> phbands."""
        return self._bands_dict

    # TODO: Just an alias. To be removed in 0.4
    bands_dict = phbands_dict

    @property
    def phdoses_dict(self):
        """Dictionary with the mapping label --> phdos."""
        return self._phdoses_dict

    @property
    def phbands_list(self):
        """List of |PhononBands| objects."""
        return list(self._bands_dict.values())

    @property
    def phdoses_list(self):
        """List of |PhononDos|."""
        return list(self._phdoses_dict.values())

    def iter_lineopt(self):
        """
        Generate matplotlib line options (linewidth, linestyle, color).

        This is a finite generator: it yields one dict per element of the cartesian product
        of _LINE_WIDTHS, _LINE_STYLES and _LINE_COLORS. Wrap it in ``itertools.cycle``
        when one style per curve is needed for an arbitrary number of curves.
        """
        for o in itertools.product(self._LINE_WIDTHS,  self._LINE_STYLES, self._LINE_COLORS):
            yield {"linewidth": o[0], "linestyle": o[1], "color": o[2]}

    def add_phbands(self, label, bands, phdos=None, dos=None, phdos_kwargs=None):
        """
        Adds a band structure for plotting.

        Args:
            label: label for the bands. Must be unique.
            bands: |PhononBands| object.
            phdos: |PhononDos| object.
            dos: Deprecated alias for ``phdos``. Mutually exclusive with ``phdos``.
            phdos_kwargs: optional dictionary with the options passed to ``get_phdos`` to compute the phonon DOS.
              Used only if ``phdos`` is not None.

        Raises:
            ValueError: if ``label`` is already present or if both ``phdos`` and ``dos`` are given.
        """
        if dos is not None:
            warnings.warn("dos has been renamed phdos. The argument will removed in abipy 0.4")
            if phdos is not None:
                raise ValueError("phdos and dos are mutually exclusive")
            phdos = dos

        if label in self._bands_dict:
            raise ValueError("label %s is already in %s" % (label, list(self._bands_dict.keys())))

        self._bands_dict[label] = PhononBands.as_phbands(bands)

        if phdos is not None:
            self.phdoses_dict[label] = PhononDos.as_phdos(phdos, phdos_kwargs)

    @add_fig_kwargs
    def combiplot(self, qlabels=None, units='eV', ylims=None, width_ratios=(2, 1), fontsize=8,
                  linestyle_dict=None, **kwargs):
        r"""
        Plot the band structure and the DOS on the same figure.
        Use ``gridplot`` to plot band structures on different figures.

        Args:
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            qlabels: dictionary whose keys are tuples with the reduced coordinates of the q-points.
                The values are the labels e.g. ``qlabels = {(0.0,0.0,0.0): "$\Gamma$", (0.5,0,0): "L"}``.
            ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
                   or scalar e.g. ``left``. If left (right) is None, default values are used
            width_ratios: Ratio between the width of the phonon bands plots and the DOS plots.
                Used if plotter has DOSes.
            fontsize: fontsize for titles and legend.
            linestyle_dict: Dictionary mapping labels to matplotlib linestyle options.
                These options override the default style generated by ``iter_lineopt``.
            kwargs: Extra keyword arguments passed to ``phbands.plot_ax`` for every curve.

        Returns: |matplotlib-Figure|
        """
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec

        # Build grid of plots.
        fig = plt.figure()
        if self.phdoses_dict:
            gspec = GridSpec(1, 2, width_ratios=width_ratios, wspace=0.05)
            ax1 = plt.subplot(gspec[0])
            # Align bands and DOS.
            ax2 = plt.subplot(gspec[1], sharey=ax1)
            ax_list = [ax1, ax2]

        else:
            ax1 = fig.add_subplot(111)
            ax_list = [ax1]

        for ax in ax_list:
            ax.grid(True)

        if ylims is not None:
            for ax in ax_list:
                set_axlims(ax, ylims, "y")

        # Plot phonon bands.
        lines, legends = [], []
        opts_label = {}
        nqpt_list = [phbands.nqpt for phbands in self._bands_dict.values()]
        if any(nq != nqpt_list[0] for nq in nqpt_list):
            cprint("WARNING combiplot: Bands have different number of q-points:\n%s" % str(nqpt_list), "yellow")

        # Cycle the default styles: a plain zip with the finite generator would silently
        # drop band structures once the style combinations are exhausted.
        lineopt_iter = itertools.cycle(self.iter_lineopt())
        for i, ((label, phbands), lineopt) in enumerate(zip(self._bands_dict.items(), lineopt_iter)):
            # Rebuild the options for every curve so that options set for one label
            # (e.g. via linestyle_dict) do not leak into the following curves.
            my_kwargs = kwargs.copy()
            my_kwargs.update(lineopt)
            if linestyle_dict is not None and label in linestyle_dict:
                my_kwargs.update(linestyle_dict[label])
            # Reused below so that each DOS is drawn with the same style as its bands.
            opts_label[label] = my_kwargs

            l = phbands.plot_ax(ax1, branch=None, units=units, **my_kwargs)
            lines.append(l[0])

            # Use relative paths if label is a file.
            if os.path.isfile(label):
                legends.append("%s" % os.path.relpath(label))
            else:
                legends.append("%s" % label)

            # Set ticks and labels, legends.
            if i == 0:
                phbands.decorate_ax(ax1, qlabels=qlabels, units=units)

        ax1.legend(lines, legends, loc='best', fontsize=fontsize, shadow=True)

        # Add DOSes
        if self.phdoses_dict:
            ax = ax_list[1]
            for label, dos in self.phdoses_dict.items():
                dos.plot_dos_idos(ax, exchange_xy=True, units=units, **opts_label[label])

        return fig

    def plot(self, *args, **kwargs):
        """An alias for combiplot."""
        return self.combiplot(*args, **kwargs)

    def yield_figs(self, **kwargs):  # pragma: no cover
        """This function *generates* a predefined list of matplotlib figures with minimal input from the user."""
        yield self.gridplot(show=False)
        yield self.boxplot(show=False)
        if self.has_same_formula():
            yield self.combiplot(show=False)
            yield self.combiboxplot(show=False)

    @add_fig_kwargs
    def gridplot(self, with_dos=True, units="eV", fontsize=8, **kwargs):
        """
        Plot multiple phonon bandstructures and optionally DOSes on a grid.

        Args:
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            with_dos: True to plot phonon DOS (if available).
            fontsize: legend and title fontsize.

        Returns: |matplotlib-Figure|
        """
        titles = list(self._bands_dict.keys())
        phb_objects = list(self._bands_dict.values())
        phdos_objects = None
        if self.phdoses_dict and with_dos:
            phdos_objects = list(self.phdoses_dict.values())

        return phbands_gridplot(phb_objects, titles=titles, phdos_objects=phdos_objects,
                                units=units, fontsize=fontsize, show=False)

    @add_fig_kwargs
    def gridplot_with_hue(self, hue, with_dos=False, units="eV", width_ratios=(2, 1),
                          ylims=None, fontsize=8, **kwargs):
        """
        Plot multiple phonon bandstructures and optionally DOSes on a grid.
        Group results by ``hue``.

        Example:

            plotter.gridplot_with_hue("tsmear")

        Args:
            hue: Variable that define subsets of the phonon bands, which will be drawn on separate plots.
                Accepts callable or string
                If string, it's assumed that the phbands has an attribute with the same name and getattr is invoked.
                Dot notation is also supported e.g. hue="structure.formula" --> abifile.structure.formula
                If callable, the output of hue(phbands) is used.
            with_dos: True to plot phonon DOS (if available).
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            width_ratios: Ratio between the width of the fatbands plots and the DOS plots.
                Used if plotter has PH DOSes is not None
            ylims: Set the data limits for the y-axis. Accept tuple e.g. `(left, right)`
                or scalar e.g. `left`. If left (right) is None, default values are used
            fontsize: legend and title fontsize.

        Returns: |matplotlib-Figure|
        """
        # Extract all quantities available in the plotter to prepare grouping.
        all_labels = list(self._bands_dict.keys())
        all_phb_objects = list(self._bands_dict.values())
        all_phdos_objects = None
        if self.phdoses_dict and with_dos:
            all_phdos_objects = list(self.phdoses_dict.values())

        # Need index to handle all_phdos_objects if DOSes are wanted.
        if callable(hue):
            items = [(hue(phb), phb, i, label) for i, (phb, label) in enumerate(zip(all_phb_objects, all_labels))]
        else:
            # Assume string. Either phbands.hue or phbands.params[hue].
            if duck.hasattrd(all_phb_objects[0], hue):
                items = [(duck.getattrd(phb, hue), phb, i, label)
                        for i, (phb, label) in enumerate(zip(all_phb_objects, all_labels))]
            else:
                items = [(phb.params[hue], phb, i, label)
                        for i, (phb, label) in enumerate(zip(all_phb_objects, all_labels))]

        # Group items by hue value.
        hvalues, groups = sort_and_groupby(items, key=lambda t: t[0], ret_lists=True)
        nrows, ncols = len(groups), 1

        if not all_phdos_objects:
            # Plot grid with phonon bands only.
            ax_phbands, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                       sharex=True, sharey=True, squeeze=False)
            ax_phbands = ax_phbands.ravel()

            # Loop over groups
            for ax, hvalue, grp in zip(ax_phbands, hvalues, groups):
                # Unzip items
                # See https://stackoverflow.com/questions/19339/transpose-unzip-function-inverse-of-zip
                _, phb_list, indices, labels = tuple(map(list, zip(*grp)))
                assert len(phb_list) == len(indices) and len(phb_list) == len(labels)
                ax.grid(True)
                sh = str(hue) if not callable(hue) else str(hue.__doc__)
                ax.set_title("%s = %s" % (sh, hvalue), fontsize=fontsize)

                nqpt_list = [phbands.nqpt for phbands in phb_list]
                if any(nq != nqpt_list[0] for nq in nqpt_list):
                    cprint("WARNING: Bands have different number of k-points:\n%s" % str(nqpt_list), "yellow")

                # Plot all bands in the group on the same axis.
                # Cycle the styles so that no band structure is dropped in large groups.
                for i, (phbands, lineopts) in enumerate(zip(phb_list, itertools.cycle(self.iter_lineopt()))):
                    # Plot all branches with lineopts and set the label of the last line produced.
                    phbands.plot_ax(ax, branch=None, units=units, **lineopts)
                    ax.lines[-1].set_label(labels[i])

                    if i == 0:
                        # Set ticks and labels
                        phbands.decorate_ax(ax, qlabels=None, units=units)

                # Set legends.
                ax.legend(loc='best', fontsize=fontsize, shadow=True)
                set_axlims(ax, ylims, "y")

        else:
            # Plot grid with phonon bands + DOS (grouped by hue)
            # see http://matplotlib.org/users/gridspec.html
            import matplotlib.pyplot as plt
            from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
            fig = plt.figure()
            gspec = GridSpec(nrows, ncols)

            # Loop over groups
            for i, (hvalue, grp) in enumerate(zip(hvalues, groups)):
                # Unzip items
                _, phb_list, indices, labels = tuple(map(list, zip(*grp)))
                assert len(phb_list) == len(indices) and len(phb_list) == len(labels)

                subgrid = GridSpecFromSubplotSpec(1, 2, subplot_spec=gspec[i], width_ratios=width_ratios, wspace=0.05)
                # Get axes and align bands and DOS.
                ax1 = plt.subplot(subgrid[0])
                ax2 = plt.subplot(subgrid[1], sharey=ax1)

                sh = str(hue) if not callable(hue) else str(hue.__doc__)
                ax1.set_title("%s = %s" % (sh, hvalue), fontsize=fontsize)

                # Plot all bands in the group on the same axis.
                nqpt_list = [phbands.nqpt for phbands in phb_list]
                if any(nq != nqpt_list[0] for nq in nqpt_list):
                    cprint("WARNING: Bands have different number of k-points:\n%s" % str(nqpt_list), "yellow")

                # indices refer to the position in the plotter, hence they select the matching DOS.
                phdos_list = [all_phdos_objects[j] for j in indices]
                lineopt_iter = itertools.cycle(self.iter_lineopt())
                for j, (phbands, phdos, lineopts) in enumerate(zip(phb_list, phdos_list, lineopt_iter)):
                    # Plot all branches with DOS and lineopts and set the label of the last line produced
                    phbands.plot_with_phdos(phdos, ax_list=(ax1, ax2), units=units, show=False, **lineopts)
                    ax1.lines[-1].set_label(labels[j])

                # Set legends on ax1
                ax1.legend(loc='best', fontsize=fontsize, shadow=True)

                for ax in (ax1, ax2):
                    set_axlims(ax, ylims, "y")

        return fig

    @add_fig_kwargs
    def boxplot(self, mode_range=None, units="eV", swarm=False, **kwargs):
        """
        Use seaborn_ to draw a box plot to show distributions of eigenvalues with respect to the band index.
        Band structures are drawn on different subplots.

        Args:
            mode_range: Only bands such as ``mode_range[0] <= nu_index < mode_range[1]`` are included in the plot.
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            swarm: True to show the datapoints on top of the boxes
            kwargs: Keyword arguments passed to seaborn_ boxplot (forwarded via ``phbands.boxplot``).
        """
        # Build grid of plots.
        num_plots, ncols, nrows = len(self.phbands_dict), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        for (label, phbands), ax in zip(self.phbands_dict.items(), ax_list):
            # swarm and kwargs were previously accepted but silently ignored.
            phbands.boxplot(ax=ax, units=units, mode_range=mode_range, swarm=swarm, show=False, **kwargs)
            ax.set_title(label)

        return fig

    @add_fig_kwargs
    def combiboxplot(self, mode_range=None, units="eV", swarm=False, ax=None, **kwargs):
        """
        Use seaborn_ to draw a box plot comparing the distributions of the frequencies.
        Phonon Band structures are drawn on the same plot.

        Args:
            mode_range: Only bands such as ``mode_range[0] <= nu_index < mode_range[1]`` are included in the plot.
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            swarm: True to show the datapoints on top of the boxes
            ax: |matplotlib-Axes| or None if a new figure should be created.
            kwargs: Keyword arguments passed to seaborn_ boxplot.
        """
        frames = []
        for label, phbands in self.phbands_dict.items():
            # Get the dataframe, select bands and add column with label
            frame = phbands.get_dataframe()
            if mode_range is not None:
                # Explicit copy: assigning a column to a boolean-filtered slice
                # triggers pandas' SettingWithCopyWarning.
                frame = frame[(frame["mode"] >= mode_range[0]) & (frame["mode"] < mode_range[1])].copy()
            frame["label"] = label
            frames.append(frame)

        # Merge frames ignoring index (not meaningful here)
        import pandas as pd
        data = pd.concat(frames, ignore_index=True)

        ax, fig, plt = get_ax_fig_plt(ax=ax)
        ax.grid(True)

        # Create column with frequencies in `units`.
        factor = abu.phfactor_ev2units(units)
        yname = "freq %s" % abu.phunit_tag(units)
        data[yname] = factor * data["freq"]

        import seaborn as sns
        sns.boxplot(x="mode", y=yname, data=data, hue="label", ax=ax, **kwargs)
        if swarm:
            sns.swarmplot(x="mode", y=yname, data=data, hue="label", color=".25", ax=ax)

        return fig

    @add_fig_kwargs
    def plot_phdispl(self, qpoint, **kwargs):
        """
        Plot vertical bars with the contribution of the different atomic types to the phonon displacements
        at a given q-point. One panel for all |PhononBands| stored in the plotter.

        Args:
            qpoint: integer, vector of reduced coordinates or |Kpoint| object.
            kwargs: keyword arguments passed to phbands.plot_phdispl

        Returns: |matplotlib-Figure|
        """
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=len(self.phbands_dict), ncols=1,
                                                sharex=False, sharey=False, squeeze=False)

        for i, (ax, (label, phbands)) in enumerate(zip(ax_list.ravel(), self.phbands_dict.items())):
            phbands.plot_phdispl(qpoint, cart_dir=None, ax=ax, show=False, **kwargs)
            # The first panel keeps the title set by plot_phdispl; the others show the label.
            if i != 0:
                ax.set_title(label, fontsize=kwargs.get("fontsize", 8))
            # Only the bottom panel keeps its x-label.
            if i != len(self.phbands_dict) - 1:
                set_visible(ax, False, "xlabel")

        return fig

    def animate(self, interval=500, savefile=None, units="eV", width_ratios=(2, 1), show=True):
        """
        Use matplotlib to animate a list of band structure plots (with or without DOS).

        Args:
            interval: draws a new frame every interval milliseconds.
            savefile: Use e.g. 'myanimation.mp4' to save the animation in mp4 format.
            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
                Case-insensitive.
            width_ratios: Ratio between the band structure plot and the dos plot.
                Used when there are DOS stored in the plotter.
            show: True if the animation should be shown immediately.

        Returns: Animation object.

        Raises:
            ValueError: if the plotter is empty or the number of DOSes differs from the number of bands.

        .. Seealso::

            http://matplotlib.org/api/animation_api.html
            http://jakevdp.github.io/blog/2012/08/18/matplotlib-animation-tutorial/

        .. Note::

            It would be nice to animate the title of the plot, unfortunately
            this feature is not available in the present version of matplotlib.
            See: http://stackoverflow.com/questions/17558096/animated-title-in-matplotlib
        """
        phbands_list, phdos_list = self.phbands_list, self.phdoses_list
        if not phbands_list:
            raise ValueError("Cannot animate an empty plotter: no phonon bands available")
        if phdos_list and len(phdos_list) != len(phbands_list):
            raise ValueError("The number of objects for DOS must be equal to the number of bands")

        import matplotlib.pyplot as plt
        fig = plt.figure()
        plotax_kwargs = {"color": "black", "linewidth": 2.0}

        # One list of artists per animation frame.
        artists = []
        if not phdos_list:
            # Animation with band structures
            ax = fig.add_subplot(1, 1, 1)
            phbands_list[0].decorate_ax(ax, units=units)
            for phbands in phbands_list:
                lines = phbands.plot_ax(ax=ax, branch=None, units=units, **plotax_kwargs)
                artists.append(lines)
        else:
            # Animation with band structures + DOS.
            from matplotlib.gridspec import GridSpec
            gspec = GridSpec(1, 2, width_ratios=width_ratios, wspace=0.05)
            ax1 = plt.subplot(gspec[0])
            ax2 = plt.subplot(gspec[1], sharey=ax1)
            # Pass units so that the y-label agrees with the plotted frequencies.
            phbands_list[0].decorate_ax(ax1, units=units)
            ax2.grid(True)
            ax2.yaxis.set_ticks_position("right")
            ax2.yaxis.set_label_position("right")

            for phbands, phdos in zip(phbands_list, phdos_list):
                phbands_lines = phbands.plot_ax(ax=ax1, branch=None, units=units, **plotax_kwargs)
                phdos_lines = phdos.plot_dos_idos(ax=ax2, units=units, exchange_xy=True, **plotax_kwargs)
                lines = phbands_lines + phdos_lines
                artists.append(lines)

        import matplotlib.animation as animation
        anim = animation.ArtistAnimation(fig, artists, interval=interval,
                                         blit=False, # True is faster but then the movie starts with an empty frame!
                                         #repeat_delay=1000
                                         )

        if savefile is not None: anim.save(savefile)
        if show: plt.show()

        return anim

    def ipw_select_plot(self): # pragma: no cover
        """
        Return an ipython widget with controllers to select the plot.
        """
        def plot_callback(plot_type, units):
            r = getattr(self, plot_type)(units=units, show=True)
            if plot_type == "animate": return r

        import ipywidgets as ipw
        return ipw.interact_manual(
                plot_callback,
                plot_type=["combiplot", "gridplot", "boxplot", "combiboxplot", "animate"],
                units=["eV", "cm-1", "Ha"],
            )

    def _repr_html_(self):
        """Integration with jupyter_ notebooks."""
        return self.ipw_select_plot()

    def get_panel(self):
        """Return tabs with widgets to interact with the |PhononBandsPlotter| file."""
        from abipy.panels.phonons import PhononBandsPlotterPanel
        return PhononBandsPlotterPanel(self).get_panel()

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        # Use pickle files for data persistence.
        tmpfile = self.pickle_dump()

        nb.cells.extend([
            #nbv.new_markdown_cell("# This is a markdown cell"),
            nbv.new_code_cell("plotter = abilab.PhononBandsPlotter.pickle_load('%s')" % tmpfile),
            nbv.new_code_cell("print(plotter)"),
            nbv.new_code_cell("frame = plotter.get_phbands_frame()\ndisplay(frame)"),
            nbv.new_code_cell("plotter.ipw_select_plot()"),
        ])

        return self._write_nb_nbpath(nb, nbpath)
4095
4096
4097class PhononDosPlotter(NotebookWriter):
4098    """
4099    Class for plotting multiple phonon DOSes.
4100
4101    Usage example:
4102
4103    .. code-block:: python
4104
4105        plotter = PhononDosPlotter()
4106        plotter.add_phdos("foo dos", "foo.nc")
4107        plotter.add_phdos("bar dos", "bar.nc")
4108        plotter.gridplot()
4109    """
4110    def __init__(self, key_phdos=None, phdos_kwargs=None):
4111        self._phdoses_dict = OrderedDict()
4112        if key_phdos is None: key_phdos = []
4113        for label, phdos in key_phdos:
4114            self.add_phdos(label, phdos, phdos_kwargs=phdos_kwargs)
4115
4116    @property
4117    def phdos_list(self):
4118        """List of phonon DOSes"""
4119        return list(self._phdoses_dict.values())
4120
4121    def add_phdos(self, label, phdos, phdos_kwargs=None):
4122        """
4123        Adds a DOS for plotting.
4124
4125        Args:
4126            label: label for the phonon DOS. Must be unique.
4127            phdos: |PhononDos| object.
4128            phdos_kwargs: optional dictionary with the options passed to `get_phdos` to compute the phonon DOS.
4129                Used when phdos is not already an instance of `cls` or when we have to compute the DOS from obj.
4130        """
4131        if label in self._phdoses_dict:
4132            raise ValueError("label %s is already in %s" % (label, list(self._phdoses_dict.keys())))
4133
4134        self._phdoses_dict[label] = PhononDos.as_phdos(phdos, phdos_kwargs)
4135
4136    #def has_same_formula(self):
4137    #    """
4138    #    True of plotter contains structures with same chemical formula.
4139    #    """
4140    #    structures = [phdos.structure for phdos in self._phdoses_dict.values()]
4141    #    if structures and any(s.formula != structures[0].formula for s in structures): return False
4142    #    return True
4143
4144    @add_fig_kwargs
4145    def combiplot(self, ax=None, units="eV", xlims=None, ylims=None, fontsize=8, **kwargs):
4146        """
4147        Plot DOSes on the same figure. Use ``gridplot`` to plot DOSes on different figures.
4148
4149        Args:
4150            ax: |matplotlib-Axes| or None if a new figure should be created.
4151            units: Units for phonon plots. Possible values in ("eV", "meV", "Ha", "cm-1", "Thz").
4152                Case-insensitive.
4153            xlims: Set the data limits for the x-axis. Accept tuple e.g. `(left, right)`
4154                   or scalar e.g. `left`. If left (right) is None, default values are used
4155            ylims: y-axis limits.
4156            fontsize: Legend and title fontsize.
4157
4158        Returns: |matplotlib-Figure|
4159        """
4160        ax, fig, plt = get_ax_fig_plt(ax=ax)
4161        ax.grid(True)
4162        set_axlims(ax, xlims, "x")
4163        set_axlims(ax, ylims, "y")
4164        ax.set_xlabel('Energy %s' % abu.phunit_tag(units))
4165        ax.set_ylabel('DOS %s' % abu.phdos_label_from_units(units))
4166
4167        lines, legends = [], []
4168        for label, dos in self._phdoses_dict.items():
4169            l = dos.plot_dos_idos(ax, units=units, **kwargs)[0]
4170            lines.append(l)
4171            legends.append("DOS: %s" % label)
4172
4173        # Set legends.
4174        ax.legend(lines, legends, loc='best', fontsize=fontsize, shadow=True)
4175
4176        return fig
4177
4178    def plot(self, **kwargs):
4179        """An alias for combiplot."""
4180        return self.combiplot(**kwargs)
4181
4182    @add_fig_kwargs
4183    def gridplot(self, units="eV", xlims=None, ylims=None, fontsize=8, **kwargs):
4184        """
4185        Plot multiple DOSes on a grid.
4186
4187        Args:
4188            units: eV for energies in ev/unit_cell, Jmol for results in J/mole.
4189            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
4190                   or scalar e.g. ``left``. If left (right) is None, default values are used
4191            fontsize: Legend and title fontsize.
4192
4193        Returns: |matplotlib-Figure|
4194        """
4195        titles = list(self._phdoses_dict.keys())
4196        phdos_list = list(self._phdoses_dict.values())
4197
4198        nrows, ncols = 1, 1
4199        numeb = len(phdos_list)
4200        if numeb > 1:
4201            ncols = 2
4202            nrows = numeb // ncols + numeb % ncols
4203
4204        # Build Grid
4205        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
4206                                                sharex=True, sharey=True, squeeze=False)
4207        ax_list = ax_list.ravel()
4208
4209        # don't show the last ax if numeb is odd.
4210        if numeb % ncols != 0: ax_list[-1].axis("off")
4211
4212        for i, (label, phdos) in enumerate(self._phdoses_dict.items()):
4213            ax = ax_list[i]
4214            phdos.plot_dos_idos(ax, units=units)
4215
4216            ax.set_xlabel('Energy %s' % abu.phunit_tag(units), fontsize=fontsize)
4217            ax.set_ylabel("DOS %s" % abu.phdos_label_from_units(units), fontsize=fontsize)
4218            ax.set_title(label, fontsize=fontsize)
4219            ax.grid(True)
4220            set_axlims(ax, xlims, "x")
4221            set_axlims(ax, ylims, "y")
4222            if i % ncols != 0:
4223                ax.set_ylabel("")
4224
4225        return fig
4226
4227    @add_fig_kwargs
4228    def plot_harmonic_thermo(self, tstart=5, tstop=300, num=50, units="eV", formula_units=1,
4229                             quantities="all", fontsize=8, **kwargs):
4230        """
4231        Plot thermodynamic properties from the phonon DOS within the harmonic approximation.
4232
4233        Args:
4234            tstart: The starting value (in Kelvin) of the temperature mesh.
4235            tstop: The end value (in Kelvin) of the mesh.
4236            num: int, optional Number of samples to generate. Default is 50.
4237            units: eV for energies in ev/unit_cell, Jmol for results in J/mole.
4238            formula_units: the number of formula units per unit cell. If unspecified, the
4239                thermodynamic quantities will be given on a per-unit-cell basis.
4240            quantities: List of strings specifying the thermodynamic quantities to plot.
4241                Possible values in ["internal_energy", "free_energy", "entropy", "c_v"].
4242            fontsize: Legend and title fontsize.
4243
4244        Returns: |matplotlib-Figure|
4245        """
4246        quantities = list_strings(quantities) if quantities != "all" else \
4247            ["internal_energy", "free_energy", "entropy", "cv"]
4248
4249        # Build grid of plots.
4250        ncols, nrows = 1, 1
4251        num_plots = len(quantities)
4252        if num_plots > 1:
4253            ncols = 2
4254            nrows = num_plots // ncols + num_plots % ncols
4255
4256        ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
4257                                               sharex=True, sharey=False, squeeze=False)
4258        # don't show the last ax if num_plots is odd.
4259        if num_plots % ncols != 0: ax_mat[-1, -1].axis("off")
4260
4261        for iax, (qname, ax) in enumerate(zip(quantities, ax_mat.flat)):
4262            for i, (label, phdos) in enumerate(self._phdoses_dict.items()):
4263                # Compute thermodynamic quantity associated to qname.
4264                f1d = getattr(phdos, "get_" + qname)(tstart=tstart, tstop=tstop, num=num)
4265                ys = f1d.values
4266                if formula_units != 1: ys /= formula_units
4267                if units == "Jmol": ys = ys * abu.e_Cb * abu.Avogadro
4268                ax.plot(f1d.mesh, ys, label=label)
4269
4270            ax.set_title(qname, fontsize=fontsize)
4271            ax.grid(True)
4272            ax.set_ylabel(_THERMO_YLABELS[qname][units], fontsize=fontsize)
4273            ax.set_xlabel("Temperature (K)", fontsize=fontsize)
4274            if iax == 0:
4275                ax.legend(loc="best", fontsize=fontsize, shadow=True)
4276
4277        return fig
4278
4279    def ipw_select_plot(self): # pragma: no cover
4280        """
4281        Return an ipython widget with controllers to select the plot.
4282        """
4283        def plot_callback(plot_type, units):
4284            getattr(self, plot_type)(units=units, show=True)
4285
4286        import ipywidgets as ipw
4287        return ipw.interact_manual(
4288                plot_callback,
4289                plot_type=["combiplot", "gridplot"],
4290                units=["eV", "meV", "cm-1", "Thz", "Ha"],
4291            )
4292
4293    def ipw_harmonic_thermo(self): # pragma: no cover
4294        """
4295        Return an ipython widget with controllers to plot thermodynamic properties
4296        from the phonon DOS within the harmonic approximation.
4297        """
4298        def plot_callback(tstart, tstop, num, units, formula_units):
4299            self.plot_harmonic_thermo(tstart=tstart, tstop=tstop, num=num,
4300                                      units=units, formula_units=formula_units, show=True)
4301
4302        import ipywidgets as ipw
4303        return ipw.interact_manual(
4304                plot_callback,
4305                tstart=5, tstop=300, num=50, units=["eV", "Jmol"], formula_units=1)
4306
4307    def yield_figs(self, **kwargs):  # pragma: no cover
4308        """
4309        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
4310        """
4311        yield self.gridplot(show=False)
4312        yield self.plot_harmonic_thermo(show=False)
4313        #if self.has_same_formula():
4314        yield self.combiplot(show=False)
4315
4316    def write_notebook(self, nbpath=None):
4317        """
4318        Write an jupyter notebook to nbpath. If nbpath is None, a temporay file in the current
4319        working directory is created. Return path to the notebook.
4320        """
4321        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)
4322
4323        # Use pickle files for data persistence.
4324        tmpfile = self.pickle_dump()
4325
4326        nb.cells.extend([
4327            #nbv.new_markdown_cell("# This is a markdown cell"),
4328            nbv.new_code_cell("plotter = abilab.ElectronDosPlotter.pickle_load('%s')" % tmpfile),
4329            nbv.new_code_cell("print(plotter)"),
4330            nbv.new_code_cell("plotter.ipw_select_plot()"),
4331            nbv.new_code_cell("plotter.ipw_harmonic_thermo()"),
4332        ])
4333
4334        return self._write_nb_nbpath(nb, nbpath)
4335
4336
class RobotWithPhbands(object):
    """
    Mixin for robots whose files expose a |PhononBands| object through the ``phbands`` attribute.

    The host class is expected to provide ``items()``, ``abifiles``, ``labels`` and
    ``get_nbformat_nbv()`` (presumably the |Robot| base class; not visible here).
    """
    def combiplot_phbands(self, **kwargs):
        """Call ``combiplot`` of the |PhononBandsPlotter| built from the robot. kwargs passed to combiplot."""
        plotter = self.get_phbands_plotter()
        return plotter.combiplot(**kwargs)

    def gridplot_phbands(self, **kwargs):
        """Call ``gridplot`` of the |PhononBandsPlotter| built from the robot. kwargs passed to gridplot."""
        plotter = self.get_phbands_plotter()
        return plotter.gridplot(**kwargs)

    def boxplot_phbands(self, **kwargs):
        """Call ``boxplot`` of the |PhononBandsPlotter| built from the robot. kwargs passed to boxplot."""
        plotter = self.get_phbands_plotter()
        return plotter.boxplot(**kwargs)

    def combiboxplot_phbands(self, **kwargs):
        """Call ``combiboxplot`` of the |PhononBandsPlotter| built from the robot. kwargs passed to combiboxplot."""
        plotter = self.get_phbands_plotter()
        return plotter.combiboxplot(**kwargs)

    #def combiplot_phdos(self, **kwargs):
    #    """Wraps combiplot method of |ElectronDosPlotter|. kwargs passed to combiplot."""
    #    return self.get_phdos_plotter().combiplot(**kwargs)
    #
    #def gridplot_phdos(self, **kwargs):
    #    """Wraps gridplot method of |ElectronDosPlotter|. kwargs passed to gridplot."""
    #    return self.get_phdos_plotter().gridplot(**kwargs)

    def get_phbands_plotter(self, filter_abifile=None, cls=None):
        """
        Build and return an instance of |PhononBandsPlotter|, or of ``cls`` if cls is not None.

        Args:
            filter_abifile: Function that receives an ``abifile`` object and returns
                True if the file should be added to the plotter. None means "add all files".
            cls: subclass of |PhononBandsPlotter|
        """
        plotter_cls = PhononBandsPlotter if cls is None else cls
        plotter = plotter_cls()

        for label, abifile in self.items():
            accepted = filter_abifile is None or filter_abifile(abifile)
            if accepted:
                plotter.add_phbands(label, abifile.phbands)

        return plotter

    def get_phbands_dataframe(self, with_spglib=True):
        """
        Build a |pandas-dataframe| with the most important results available in the band structures.
        Rows are indexed by the robot labels.
        """
        all_phbands = [abifile.phbands for abifile in self.abifiles]
        return dataframe_from_phbands(all_phbands, index=self.labels, with_spglib=with_spglib)

    @add_fig_kwargs
    def plot_phdispl(self, qpoint, **kwargs):
        """
        Plot vertical bars with the contribution of the different atomic types to the phonon displacements
        at a given q-point. One panel for all phbands stored in the plotter.

        Args:
            qpoint: integer, vector of reduced coordinates or |Kpoint| object.
            kwargs: keyword arguments passed to phbands.plot_phdispl

        Returns: |matplotlib-Figure|
        """
        plotter = self.get_phbands_plotter()
        # The inner figure is not shown: showing is handled by the add_fig_kwargs decorator of this method.
        return plotter.plot_phdispl(qpoint, show=False, **kwargs)

    def get_phbands_code_cells(self, title=None):
        """Return the list of notebook cells used to compare the phonon band structures."""
        _, nbv = self.get_nbformat_nbv()
        if title is None:
            header = "## Code to compare multiple PhononBands objects"
        else:
            header = str(title)

        markdown_cell = nbv.new_markdown_cell(header)
        plotter_cell = nbv.new_code_cell("robot.get_phbands_plotter().ipw_select_plot();")
        phdispl_cell = nbv.new_code_cell("#robot.plot_phdispl(qpoint=(0, 0, 0));")
        return [markdown_cell, plotter_cell, phdispl_cell]
4413
4414
4415# TODO: PhdosRobot
class PhbstRobot(Robot, RobotWithPhbands):
    """
    Robot for the analysis of the results stored in multiple PHBST.nc files.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: PhbstRobot
    """
    EXT = "PHBST"

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        Generate a predefined list of matplotlib figures with minimal input from the user
        by delegating to the |PhononBandsPlotter| built from the robot.
        Used in abiview.py to get a quick look at the results.
        """
        phbands_plotter = self.get_phbands_plotter()
        for figure in phbands_plotter.yield_figs():
            yield figure

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to nbpath. If ``nbpath`` is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        label_path_pairs = [(label, abifile.filepath) for label, abifile in self.items()]
        init_code = "robot = abilab.PhbstRobot(*%s)\nrobot.trim_paths()\nrobot" % str(label_path_pairs)
        nb.cells.extend([nbv.new_code_cell(init_code)])

        # Cells contributed by the base robot and by the phbands mixin.
        nb.cells.extend(self.get_baserobot_code_cells())
        nb.cells.extend(self.get_phbands_code_cells())

        return self._write_nb_nbpath(nb, nbpath)
4451