1# coding: utf-8
2"""This module provides mixin classes"""
3import abc
4import os
5import collections
6import tempfile
7import pickle
8import numpy as np
9
10from time import ctime
11from monty.os.path import which
12from monty.termcolor import cprint
13from monty.string import list_strings
14from monty.collections import dict2namedtuple
15from monty.functools import lazy_property
16
17
# Public names exported by ``from <this module> import *``.
# Other classes defined here (e.g. BaseFile, TextFile) are not listed.
__all__ = [
    "AbinitNcFile",
    "Has_Structure",
    "Has_ElectronBands",
    "Has_PhononBands",
    "NotebookWriter",
    "Has_Header",
]
26
27
class BaseFile(metaclass=abc.ABCMeta):
    """
    Abstract base class defining the methods that must be implemented
    by the concrete classes representing the different files produced by ABINIT.

    Concrete subclasses must implement :meth:`close`. Instances can be used
    as context managers: :meth:`close` is called when the ``with`` block exits.
    """
    def __init__(self, filepath):
        """
        Args:
            filepath: Path to an existing file. It is stored as an absolute path.

        Raises:
            OSError: if ``os.stat`` fails, e.g. if the file does not exist.
        """
        self._filepath = os.path.abspath(filepath)

        # Save the stat values at construction time.
        stat = os.stat(filepath)
        self._last_atime = stat.st_atime
        self._last_mtime = stat.st_mtime
        self._last_ctime = stat.st_ctime

    def __repr__(self):
        return "<%s, %s>" % (self.__class__.__name__, self.relpath)

    @classmethod
    def from_file(cls, filepath):
        """
        Initialize the object from a file path.

        If ``filepath`` is already an instance of ``cls``, it is returned unchanged.
        Subclasses whose constructor does not accept a single path argument
        should redefine this classmethod.
        """
        if isinstance(filepath, cls):
            return filepath
        return cls(filepath)

    @property
    def filepath(self):
        """Absolute path of the file."""
        return self._filepath

    @property
    def relpath(self):
        """Path relative to the current working directory (absolute path if it cannot be computed)."""
        try:
            return os.path.relpath(self.filepath)
        except (OSError, ValueError):
            # OSError: the current working directory may not be defined.
            # ValueError: on Windows there is no relative path between different drives.
            return self.filepath

    @property
    def basename(self):
        """Basename of the file."""
        return os.path.basename(self.filepath)

    @property
    def filetype(self):
        """String defining the filetype (the name of the concrete class)."""
        return self.__class__.__name__

    def filestat(self, as_string=False):
        """
        Dictionary with file metadata computed by ``get_filestat``.
        If ``as_string`` is True, a string with one "key: value" line per entry is returned.
        """
        d = get_filestat(self.filepath)
        if not as_string:
            return d
        return "\n".join("%s: %s" % (k, v) for k, v in d.items())

    @abc.abstractmethod
    def close(self):
        """Close the file."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Activated at the end of the with statement. It automatically closes the file."""
        self.close()
95
96
class TextFile(BaseFile):
    """
    Base class for text files.

    The underlying file object is opened lazily, in read-only text mode, on first use
    (entering a ``with`` block, iteration or :meth:`seek`) and released by :meth:`close`.
    """

    def __enter__(self):
        # Accessing the lazy property opens the file.
        self._file
        return self

    def __iter__(self):
        return iter(self._file)

    @lazy_property
    def _file(self):
        """File object open in read-only mode."""
        return open(self.filepath, mode="rt")

    def close(self):
        """Close the file if it has been opened. Errors raised while closing are ignored."""
        # lazy_property stores the computed value in the instance ``__dict__`` under the
        # attribute name on first access. Reading it from there (instead of self._file)
        # avoids opening a file that was never used just to close it again.
        fh = self.__dict__.get("_file")
        if fh is None:
            return
        try:
            fh.close()
        except Exception:
            # Closing is best-effort.
            pass

    def seek(self, offset, whence=0):
        """Set the file's current position, like stdio's fseek()."""
        self._file.seek(offset, whence)
126
127
class AbinitNcFile(BaseFile):
    """
    Abstract class representing a Netcdf file with data saved
    according to the ETSF-IO specifications (when available).
    An AbinitNcFile has a netcdf reader to read data from file and build objects.

    Subclasses must provide a ``reader`` attribute (used by :attr:`abinit_version`)
    and implement :meth:`close` and :attr:`params`.
    """
    def ncdump(self, *nc_args, **nc_kwargs):
        """
        Returns a string with the output of ncdump for this file.
        ``nc_args`` and ``nc_kwargs`` are forwarded to :class:`NcDumper`.
        """
        return NcDumper(*nc_args, **nc_kwargs).dump(self.filepath)

    @lazy_property
    def abinit_version(self):
        """String with the abinit version, read from the ``abinit_version`` attribute of the root group."""
        return self.reader.rootgrp.getncattr("abinit_version")

    @property
    @abc.abstractmethod
    def params(self):
        """
        :class:`OrderedDict` with the convergence parameters
        Used to construct |pandas-DataFrames|.
        """
154
155
class AbinitFortranFile(BaseFile):
    """
    Abstract class representing a fortran file containing output data from abinit.

    This class does not keep any file handle open, so :meth:`close` has nothing to do.
    """
    def close(self):
        """No-op, defined only to implement the abstract interface of BaseFile."""
        return None
162
163
class CubeFile(BaseFile):
    """
    Cube file parsed at construction time with
    ``abipy.iotools.cube.cube_read_structure_mesh_data``.

    .. attribute:: structure

        |Structure| object

    .. attribute:: mesh

        |Mesh3d| object with information on the uniform 3d mesh.

    .. attribute:: data

        |numpy-array| of shape [nx, ny, nz] with numerical values on the real-space mesh.
    """
    def __init__(self, filepath):
        from abipy.iotools.cube import cube_read_structure_mesh_data
        super().__init__(filepath)
        # The reader returns the (structure, mesh, data) triple.
        parsed = cube_read_structure_mesh_data(self.filepath)
        self.structure, self.mesh, self.data = parsed

    def close(self):
        """No-op: all data is read in the constructor. Required by the abstract interface."""
        return None
192
193
class Has_Structure(metaclass=abc.ABCMeta):
    """Mixin class for |AbinitNcFile| containing crystallographic data."""

    @property
    @abc.abstractmethod
    def structure(self):
        """Returns the |Structure| object."""

    def plot_bz(self, **kwargs):
        """
        Gives the plot (as a matplotlib object) of the symmetry line path in the Brillouin Zone.
        Keyword arguments are forwarded to ``self.structure.plot_bz``.
        """
        return self.structure.plot_bz(**kwargs)

    # Alias kept to maintain backward compatibility.
    show_bz = plot_bz

    def export_structure(self, filepath):
        """
        Export the structure on file.

        Args:
            filepath: Destination path, forwarded to ``self.structure.export``.

        returns: the object returned by ``self.structure.export`` (documented as a |Visualizer| instance).
        """
        return self.structure.export(filepath)

    def visualize_structure_with(self, appname):
        """
        Visualize the crystalline structure with the specified visualizer.

        The extensions supported by the visualizer are tried in order and the result
        of the first successful export is returned.
        See |Visualizer| for the list of applications and formats supported.

        Raises:
            The visualizer's ``Error`` exception if no supported format can be exported.
        """
        from abipy.iotools.visualizer import Visualizer
        visu = Visualizer.from_name(appname)

        for ext in visu.supported_extensions():
            try:
                return self.export_structure("." + ext)
            except visu.Error:
                # Try the next supported extension.
                continue

        raise visu.Error("Don't know how to export data for appname %s" % appname)

    def _get_atomview(self, view, select_symbols=None, verbose=0):
        """
        Helper function used to select (inequivalent||all) atoms depending on view.
        Uses spglib to find inequivalent sites.

        Args:
            view: "inequivalent" to show only inequivalent atoms. "all" for all sites.
            select_symbols: String or list of strings with chemical symbols.
                Used to select only atoms of this type.
            verbose: Verbosity level. Forced to False if the structure has a single atom.

        Return named tuple with:

                * iatom_list: array with the indices of the selected sites.
                * wyckoffs: Wyckoff letters
                * wyck_labels: Wyckoff labels for each site in `iatom_list`
                * site_labels: Labels for each site in `iatom_list` e.g Si2a

        Raises:
            ValueError: if ``view`` is neither "all" nor "inequivalent".
        """
        natom = len(self.structure)
        if natom == 1:
            verbose = False
        if verbose:
            print("Calling spglib to find inequivalent sites. Magnetic symmetries (if any) are not taken into account.")

        ea = self.structure.spget_equivalent_atoms(printout=verbose > 0)

        # Define iatom_list depending on view
        if view == "all":
            iatom_list = np.arange(natom)
        elif view == "inequivalent":
            iatom_list = ea.irred_pos
        else:
            raise ValueError("Wrong value for view: %s" % str(view))

        # Filter by element symbol.
        if select_symbols is not None:
            select_symbols = set(list_strings(select_symbols))
            iatom_list = [i for i in iatom_list if self.structure[i].specie.symbol in select_symbols]
            # Keep an integer array so that it can be used for fancy indexing below (even if empty).
            iatom_list = np.array(iatom_list, dtype=int)

        # Slice full arrays.
        wyckoffs = ea.wyckoffs[iatom_list]
        wyck_labels = ea.wyck_labels[iatom_list]
        site_labels = ea.site_labels[iatom_list]

        return dict2namedtuple(iatom_list=iatom_list, wyckoffs=wyckoffs, wyck_labels=wyck_labels, site_labels=site_labels)

    def yield_structure_figs(self, **kwargs):
        """*Generates* a predefined list of matplotlib figures with minimal input from the user."""
        yield self.structure.plot(show=False)
283
284
class Has_ElectronBands(metaclass=abc.ABCMeta):
    """
    Mixin class for |AbinitNcFile| containing electron data.

    Subclasses must implement the ``ebands`` property; the other properties
    and methods delegate to that object.
    """

    @property
    @abc.abstractmethod
    def ebands(self):
        """Returns the |ElectronBands| object."""

    @property
    def nsppol(self):
        """Number of spin polarizations"""
        return self.ebands.nsppol

    @property
    def nspinor(self):
        """Number of spinors"""
        return self.ebands.nspinor

    @property
    def nspden(self):
        """Number of independent spin-density components."""
        return self.ebands.nspden

    @property
    def mband(self):
        """Maximum number of bands."""
        return self.ebands.mband

    @property
    def nband(self):
        """Number of bands, as given by ``ebands.nband``."""
        return self.ebands.nband

    @property
    def nelect(self):
        """Number of electrons per unit cell"""
        return self.ebands.nelect

    @property
    def nkpt(self):
        """Number of k-points."""
        return self.ebands.nkpt

    @property
    def kpoints(self):
        """Iterable with the Kpoints."""
        return self.ebands.kpoints

    @lazy_property
    def tsmear(self):
        """Smearing parameter ``ebands.smearing.tsmear_ev`` converted to Hartree (computed on first access)."""
        return self.ebands.smearing.tsmear_ev.to("Ha")

    def get_ebands_params(self):
        """:class:`OrderedDict` with the convergence parameters."""
        return collections.OrderedDict([
            ("nsppol", self.nsppol),
            ("nspinor", self.nspinor),
            ("nspden", self.nspden),
            ("nband", self.nband),
            ("nkpt", self.nkpt),
        ])

    def plot_ebands(self, **kwargs):
        """Plot the electron energy bands. See the :func:`ElectronBands.plot` for the signature."""
        return self.ebands.plot(**kwargs)

    def plot_ebands_with_edos(self, edos, **kwargs):
        """Plot the electron energy bands with DOS. See the :func:`ElectronBands.plot_with_edos` for the signature."""
        return self.ebands.plot_with_edos(edos, **kwargs)

    def get_edos(self, **kwargs):
        """Compute the electronic DOS on a linear mesh. Wraps ebands.get_edos."""
        return self.ebands.get_edos(**kwargs)

    def yield_ebands_figs(self, **kwargs):
        """*Generates* a predefined list of matplotlib figures with minimal input from the user."""
        # Band gaps are shown only if the bands do not use a metallic scheme.
        with_gaps = not self.ebands.has_metallic_scheme
        if self.ebands.kpoints.is_path:
            yield self.ebands.plot(with_gaps=with_gaps, show=False)
            yield self.ebands.kpoints.plot(show=False)
        else:
            edos = self.ebands.get_edos()
            yield self.ebands.plot_with_edos(edos, with_gaps=with_gaps, show=False)
            yield edos.plot(show=False)

    def expose_ebands(self, slide_mode=False, slide_timeout=None, **kwargs):
        """
        Shows a predefined list of matplotlib figures for electron bands with minimal input from the user.

        Args:
            slide_mode: Passed to ``MplExpose``.
            slide_timeout: Passed to ``MplExpose``.
            kwargs: Passed to :meth:`yield_ebands_figs`.
        """
        from abipy.tools.plotting import MplExpose
        with MplExpose(slide_mode=slide_mode, slide_timeout=slide_timeout, verbose=1) as e:
            e(self.yield_ebands_figs(**kwargs))
376
377
class Has_PhononBands(metaclass=abc.ABCMeta):
    """
    Mixin class for |AbinitNcFile| containing phonon data.

    Subclasses must implement the ``phbands`` property.
    """

    @property
    @abc.abstractmethod
    def phbands(self):
        """Returns the |PhononBands| object."""

    def get_phbands_params(self):
        """:class:`OrderedDict` with the convergence parameters (number of q-points)."""
        return collections.OrderedDict([
            ("nqpt", len(self.phbands.qpoints)),
        ])

    def plot_phbands(self, **kwargs):
        """
        Plot the phonon bands. See the :func:`PhononBands.plot` for the signature.
        """
        return self.phbands.plot(**kwargs)

    def yield_phbands_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Used in abiview.py to get a quick look at the results.

        Only the ``units`` keyword (default "mev") is used.
        """
        units = kwargs.get("units", "mev")
        yield self.phbands.qpoints.plot(show=False)
        yield self.phbands.plot(units=units, show=False)
        yield self.phbands.plot_colored_matched(units=units, show=False)

    def expose_phbands(self, slide_mode=False, slide_timeout=None, **kwargs):
        """
        Shows a predefined list of matplotlib figures for phonon bands with minimal input from the user.

        Args:
            slide_mode: Passed to ``MplExpose``.
            slide_timeout: Passed to ``MplExpose``.
            kwargs: Passed to :meth:`yield_phbands_figs`.
        """
        from abipy.tools.plotting import MplExpose
        with MplExpose(slide_mode=slide_mode, slide_timeout=slide_timeout, verbose=1) as e:
            e(self.yield_phbands_figs(**kwargs))
419
420
class NcDumper(object):
    """Wrapper object for the ncdump tool."""

    def __init__(self, *nc_args, **nc_kwargs):
        """
        Args:
            nc_args: Command-line arguments passed to ncdump before the file path (e.g. "-h").
            nc_kwargs: Keyword arguments. Stored in ``self.nc_kwargs`` but not used by :meth:`dump`.
        """
        self.nc_args = nc_args
        self.nc_kwargs = nc_kwargs
        # Path to the ncdump executable found in $PATH, or None if not available.
        self.ncdump = which("ncdump")

    def dump(self, filepath):
        """
        Returns a string with the output of ncdump.

        If the ncdump executable is not available, an error message is returned instead.

        Raises:
            subprocess.CalledProcessError: if ncdump exits with a non-zero status.
        """
        if self.ncdump is None:
            return "Cannot find ncdump tool in $PATH"

        from subprocess import check_output
        # Options must precede the file path on the ncdump command line.
        cmd = [self.ncdump]
        cmd.extend(str(arg) for arg in self.nc_args)
        cmd.append(filepath)
        # Decode the output so that both branches return a str.
        return check_output(cmd, universal_newlines=True)
441
442
# (factor, suffix) pairs sorted from the largest to the smallest unit.
_ABBREVS = [
    (1 << 50, 'Pb'),
    (1 << 40, 'Tb'),
    (1 << 30, 'Gb'),
    (1 << 20, 'Mb'),
    (1 << 10, 'kb'),
    (1, 'b'),
]


def size2str(size):
    """
    Convert a size in bytes to a string with units, e.g. 1536 --> "1.50 kb".

    The largest unit whose factor does not exceed ``size`` is used.
    Sizes smaller than 1 (e.g. 0) are reported with the smallest unit.
    """
    for factor, suffix in _ABBREVS:
        # Use >= so that exact powers of 1024 switch to the larger unit (1024 --> "1.00 kb").
        if size >= factor:
            break
    # If nothing matched, the loop leaves the last (smallest) unit in factor/suffix.
    return "%.2f " % (size / factor) + suffix


def get_filestat(filepath):
    """
    Return an OrderedDict with metadata on ``filepath``: name, directory, size formatted
    by :func:`size2str`, and access/modification/change times formatted by ``time.ctime``.

    Raises:
        OSError: if ``os.stat`` fails.
    """
    stat = os.stat(filepath)
    return collections.OrderedDict([
        ("Name", os.path.basename(filepath)),
        ("Directory", os.path.dirname(filepath)),
        ("Size", size2str(stat.st_size)),
        ("Access Time", ctime(stat.st_atime)),
        ("Modification Time", ctime(stat.st_mtime)),
        ("Change Time", ctime(stat.st_ctime)),
    ])
471
472
class NotebookWriter(metaclass=abc.ABCMeta):
    """
    Mixin class for objects that are able to generate jupyter_ notebooks.
    Subclasses must provide a concrete implementation of `write_notebook` and `yield_figs`.
    """

    def make_and_open_notebook(self, nbpath=None, foreground=False,
                               classic_notebook=False, no_browser=False):  # pragma: no cover
        """
        Generate an jupyter_ notebook and open it in the browser.

        Args:
            nbpath: If nbpath is None, a temporary file is created.
            foreground: By default, jupyter is executed in background and stdout, stderr are redirected
                to a temporary file. Use foreground to run the process in foreground
            classic_notebook: True to use the classic notebook instead of jupyter-lab (default)
            no_browser: Start the jupyter server to serve the notebook but don't open the notebook in the browser.
                        Use this option to connect remotely from localhost to the machine running the kernel

        Return: system exit code.

        Raise: `RuntimeError` if jupyter executable is not in $PATH
        """
        nbpath = self.write_notebook(nbpath=nbpath)
        app_argv = self._get_jupyter_argv(classic_notebook)

        if not no_browser:
            return self._open_notebook_locally(app_argv, nbpath, foreground)
        return self._serve_notebook_remotely(app_argv, nbpath)

    @staticmethod
    def _get_jupyter_argv(classic_notebook):  # pragma: no cover
        """
        Return the list with the executable (and subcommand) used to start jupyter.

        Raise: `RuntimeError` if the required executable is not in $PATH
        """
        if not classic_notebook:
            # Use jupyter-lab.
            app_path = which("jupyter-lab")
            if app_path is None:
                raise RuntimeError("""
Cannot find jupyter-lab application in $PATH. Install it with:

    conda install -c conda-forge jupyterlab

or:

    pip install jupyterlab

See also https://jupyterlab.readthedocs.io/
""")
            return [app_path]

        # Use classic notebook.
        app_path = which("jupyter")
        if app_path is None:
            raise RuntimeError("""
Cannot find jupyter application in $PATH. Install it with:

    conda install -c conda-forge jupyter

or:

    pip install jupyterlab

See also https://jupyter.readthedocs.io/en/latest/install.html
""")
        return [app_path, "notebook"]

    @staticmethod
    def _open_notebook_locally(app_argv, nbpath, foreground):  # pragma: no cover
        """
        Run jupyter on nbpath. In foreground, return the value of os.system.
        In background, redirect stdout and stderr to a temporary file and return 0.
        """
        if foreground:
            return os.system("%s %s" % (" ".join(app_argv), nbpath))

        fd, tmpname = tempfile.mkstemp(text=True)
        print(tmpname)
        argv = app_argv + [nbpath]
        cmd = " ".join(argv)
        print("Executing:", cmd, "\nstdout and stderr redirected to %s" % tmpname)
        import subprocess
        try:
            # Pass the argument list directly so that paths with spaces are not split.
            process = subprocess.Popen(argv, shell=False, stdout=fd, stderr=fd)
        finally:
            # The child process has its own copy of the descriptor: close ours to avoid a leak.
            os.close(fd)
        cprint("pid: %s" % str(process.pid), "yellow")
        return 0

    @staticmethod
    def _serve_notebook_remotely(app_argv, nbpath):  # pragma: no cover
        """
        Start jupyter without a browser on a free port and print the ssh command
        to be executed on the local machine. Return the value of os.system.
        """
        # Based on https://github.com/arose/nglview/blob/master/nglview/scripts/nglview.py
        notebook_name = os.path.basename(nbpath)
        # Use the absolute path so that --notebook-dir is never empty for relative paths.
        dirname = os.path.dirname(os.path.abspath(nbpath))
        print("nbpath:", nbpath)

        import socket
        from contextlib import closing

        def find_free_port():
            """https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number"""
            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                return s.getsockname()[1]

        try:
            username = os.getlogin()
        except OSError:
            # os.getlogin fails when the process has no controlling terminal.
            import getpass
            username = getpass.getuser()
        hostname = socket.gethostname()
        port = find_free_port()

        client_cmd = "ssh -NL localhost:{port}:localhost:{port} {username}@{hostname}".format(
            username=username, hostname=hostname, port=port)

        print(f"""
Using port: {port}

\033[32m In your local machine, run: \033[0m

                {client_cmd}

\033[32m NOTE: you might want to replace {hostname} by full hostname with domain name \033[0m
\033[32m Then open your web browser, copy and paste the URL: \033[0m

http://localhost:{port}/notebooks/{notebook_name}
""")
        # For the classic notebook, app_argv already contains the "notebook" subcommand.
        app = " ".join(app_argv)
        cmd = f'{app} {notebook_name} --no-browser --port {port} --notebook-dir {dirname}'

        print("Executing:", cmd)
        print('NOTE: make sure to open `{}` in your local machine\n'.format(notebook_name))

        return os.system(cmd)

    @staticmethod
    def get_nbformat_nbv():
        """Return nbformat module, notebook version module"""
        import nbformat
        nbv = nbformat.v4
        return nbformat, nbv

    def get_nbformat_nbv_nb(self, title=None):
        """
        Return ``nbformat`` module, notebook version module
        and new notebook with title and import section
        """
        nbformat, nbv = self.get_nbformat_nbv()
        nb = nbv.new_notebook()

        if title is not None:
            nb.cells.append(nbv.new_markdown_cell("## %s" % title))

        nb.cells.extend([
            nbv.new_code_cell("""\
import sys, os
import numpy as np

%matplotlib notebook

# Use this magic for jupyterlab.
# For installation instructions, see https://github.com/matplotlib/jupyter-matplotlib
#%matplotlib widget

from IPython.display import display

# This to render pandas DataFrames with https://github.com/quantopian/qgrid
#import qgrid
#qgrid.nbinstall(overwrite=True)  # copies javascript dependencies to your /nbextensions folder

# This to view Mayavi visualizations. See http://docs.enthought.com/mayavi/mayavi/tips.html
#from mayavi import mlab; mlab.init_notebook(backend='x3d', width=None, height=None, local=True)

from abipy import abilab

# Tell AbiPy we are inside a notebook and use seaborn settings for plots.
# See https://seaborn.pydata.org/generated/seaborn.set.html#seaborn.set
abilab.enable_notebook(with_seaborn=True)
""")
        ])

        return nbformat, nbv, nb

    @abc.abstractmethod
    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to nbpath. If nbpath is None, a temporary file is created.
        Return path to the notebook. A typical template:

        .. code-block:: python

            # Preamble.
            nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

            #####################
            # Put your code here
            nb.cells.extend([
                nbv.new_markdown_cell("# This is a markdown cell"),
                nbv.new_code_cell("a = 1"),
            ])
            #####################

            # Call _write_nb_nbpath
            return self._write_nb_nbpath(nb, nbpath)
        """

    @staticmethod
    def _write_nb_nbpath(nb, nbpath):
        """
        This method must be called at the end of ``write_notebook``.
        nb is the jupyter notebook and nbpath the argument passed to ``write_notebook``.
        If nbpath is None, a temporary file with prefix "abinb_" is created in the current working directory.

        Return: path of the notebook file.
        """
        if nbpath is None:
            fd, nbpath = tempfile.mkstemp(prefix="abinb_", suffix='.ipynb', dir=os.getcwd(), text=True)
            # Only the file name is needed: close the descriptor returned by mkstemp.
            os.close(fd)

        # Write notebook
        import nbformat
        with open(nbpath, 'wt', encoding="utf8") as fh:
            nbformat.write(nb, fh)

        return nbpath

    @classmethod
    def pickle_load(cls, filepath):
        """
        Loads the object from a pickle file.

        Warning: unpickling can execute arbitrary code, never load files from untrusted sources.
        """
        with open(filepath, "rb") as fh:
            return pickle.load(fh)

    def pickle_dump(self, filepath=None):
        """
        Save the status of the object in pickle format.
        If filepath is None, a temporary file is created.

        Return: The name of the pickle file.
        """
        if filepath is None:
            fd, filepath = tempfile.mkstemp(suffix='.pickle')
            # Only the file name is needed: close the descriptor returned by mkstemp.
            os.close(fd)

        with open(filepath, "wb") as fh:
            pickle.dump(self, fh)

        return filepath

    @abc.abstractmethod
    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Used in abiview.py to get a quick look at the results.
        """

    def expose(self, slide_mode=False, slide_timeout=None, **kwargs):
        """
        Shows a predefined list of matplotlib figures with minimal input from the user.

        Args:
            slide_mode: Passed to ``MplExpose``.
            slide_timeout: Passed to ``MplExpose``.
            kwargs: Passed to :meth:`yield_figs`.
        """
        from abipy.tools.plotting import MplExpose
        with MplExpose(slide_mode=slide_mode, slide_timeout=slide_timeout, verbose=1) as e:
            e(self.yield_figs(**kwargs))
715
716
class Has_Header:
    """
    Mixin class for netcdf files containing the Abinit header.

    The host class must provide a ``reader`` attribute.
    """

    @lazy_property
    def hdr(self):
        """
        |AttrDict| with the Abinit header e.g. hdr.ecut.

        Obtained from ``self.reader.read_abinit_hdr()`` on first access.
        """
        reader = self.reader
        return reader.read_abinit_hdr()
726