1# coding: utf-8
2# flake8: noqa
3import numpy as np
4import pymatgen.io.abinit.netcdf as ionc
5
6from monty.functools import lazy_property
7from pymatgen.core.periodic_table import Element
8from .xsf import *
9from .visualizer import *
10
11
# Bind pymatgen's ``as_etsfreader`` helper to a module-level name so that it
# can be accessed from this module as well.
as_etsfreader = ionc.as_etsfreader
13
14
class ETSF_Reader(ionc.ETSF_Reader):
    """
    Provides high-level API to read data from netcdf files written
    following the ETSF-IO specifications described in :cite:`Caliste2008`
    """

    def read_structure(self):
        """
        Overrides the ``read_structure`` method so that we always return
        an instance of AbiPy |Structure| object

        Returns:
            Object built by ``Structure.from_file`` from ``self.path``.
        """
        # Function-scope import (presumably to avoid a circular import
        # between this module and abipy.core.structure -- TODO confirm).
        from abipy.core.structure import Structure
        return Structure.from_file(self.path)

    # Must overwrite implementation of pymatgen.io.abinit.netcdf
    # due to a possible bug introduced by initial whitespaces in symbol
    @lazy_property
    def chemical_symbols(self):
        """
        Chemical symbols char [number of atom species][symbol length].

        Returns:
            List of strings, one per row of the ``chemical_symbols`` variable,
            with leading and trailing whitespace removed.
        """
        charr = self.read_value("chemical_symbols")
        # Every row is a sequence of single encoded characters: decode them,
        # glue them together and strip to avoid possible whitespaces.
        return ["".join(c.decode("utf-8") for c in row).strip() for row in charr]

    def read_string(self, varname):
        """
        Read the character variable ``varname`` and return it as a string.

        Args:
            varname: Name of the variable

        Returns:
            Content of the variable with leading and trailing whitespace removed.
        """
        b = self.rootgrp.variables[varname][:]
        import netCDF4

        try:
            value = netCDF4.chartostring(b)[()]
        except Exception:
            # chartostring could not handle this variable: fall back to
            # joining the single characters returned by read_value.
            chars = self.read_value(varname)
            try:
                value = "".join(chars)
            except TypeError:
                # The items are not str objects: decode each of them first.
                value = "".join(c.decode("utf-8") for c in chars)
        else:
            # The converted scalar may or may not expose ``decode``.
            # Decode it when possible, otherwise keep the value untouched.
            try:
                value = value.decode("utf-8")
            except (AttributeError, UnicodeDecodeError):
                pass

        return value.strip()

    def none_if_masked_array(self, arr):
        """
        Return None if ``arr`` contains masked values, else return ``arr``.

        Args:
            arr: Object tested with ``np.ma.is_masked``.
        """
        return None if np.ma.is_masked(arr) else arr

    def read_amu_symbol(self):
        """
        Read atomic masses and return dictionary element_symbol --> amu.

        Raises:
            RuntimeError: if ``atomic_mass_units`` or ``atomic_numbers``
                is not among the variables of the file.

        .. note::

            Only netcdf files with phonon-related quantities contain this variable.
        """
        variables = self.rootgrp.variables
        for k in ("atomic_mass_units", "atomic_numbers"):
            if k not in variables:
                raise RuntimeError("`%s` does not contain `%s` variable." % (self.path, k))

        # ntypat arrays
        amu_list = self.read_value("atomic_mass_units")
        atomic_numbers = self.read_value("atomic_numbers")

        # Map each atomic number to its element symbol. If an atomic number
        # appears more than once, the last mass read for it is kept.
        return {Element.from_Z(z).symbol: amu for z, amu in zip(atomic_numbers, amu_list)}

    def read_ngfft3(self):
        """
        Return the number of FFT divisions.

        Returns:
            Integer numpy array with the values of the three dimensions
            ``number_of_grid_points_vector{1,2,3}``.
        """
        ngfft3 = [
            self.read_dimvalue("number_of_grid_points_vector%d" % i) for i in (1, 2, 3)
        ]
        return np.array(ngfft3, int)
97