1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4#
5# pylint: disable=no-member
6"""Wrapper for netCDF readers."""
7
8import logging
9import os.path
10import warnings
11from collections import OrderedDict
12
13import numpy as np
14from monty.collections import AttrDict
15from monty.dev import requires
16from monty.functools import lazy_property
17from monty.string import marquee
18
19from pymatgen.core.structure import Structure
20from pymatgen.core.units import ArrayWithUnit
21from pymatgen.core.xcfunc import XcFunc
22
# Module-level logger (used e.g. by NetcdfReader.close to report errors while closing a file).
logger = logging.getLogger(__name__)

# Authorship / versioning metadata.
__author__ = "Matteo Giantomassi"
__copyright__ = "Copyright 2013, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Matteo Giantomassi"
__email__ = "gmatteo at gmail.com"
__status__ = "Development"
__date__ = "$Feb 21, 2013M$"

# Public API: the names exported by `from <this module> import *`.
__all__ = [
    "as_ncreader",
    "as_etsfreader",
    "NetcdfReader",
    "ETSF_Reader",
    "NO_DEFAULT",
    "structure_from_ncdata",
]
41
# netCDF4 is an optional dependency: if it cannot be imported, the module still loads,
# `netCDF4` is bound to None and a warning with installation hints is issued.
# NetcdfReader.__init__ is guarded by `requires(netCDF4 is not None, ...)`.
try:
    import netCDF4
except ImportError as exc:
    netCDF4 = None
    warnings.warn(
        """\
`import netCDF4` failed with the following error:

%s

Please install netcdf4 with `conda install netcdf4`
If the conda version does not work, uninstall it with `conda uninstall hdf4 hdf5 netcdf4`
and use `pip install netcdf4`"""
        % str(exc)
    )
57
58
59def _asreader(file, cls):
60    closeit = False
61    if not isinstance(file, cls):
62        file, closeit = cls(file), True
63    return file, closeit
64
65
def as_ncreader(file):
    """
    Return a ``(reader, closeit)`` tuple for ``file``.

    ``file`` may be a filename or an existing NetcdfReader. ``closeit`` is True only
    when a new reader has been opened here, in which case the caller has to close the
    file before leaving the procedure.
    """
    reader, closeit = _asreader(file, NetcdfReader)
    return reader, closeit
73
74
def as_etsfreader(file):
    """
    Return a ``(reader, closeit)`` tuple where reader is an ETSF_Reader.

    ``file`` may be a filename or an ETSF_Reader; ``closeit`` is True when a new
    reader has been opened here and must be closed by the caller.
    """
    reader, closeit = _asreader(file, ETSF_Reader)
    return reader, closeit
78
79
class NetcdfReaderError(Exception):
    """Base exception raised by NetcdfReader (available as ``NetcdfReader.Error``)."""

    pass
82
83
class NO_DEFAULT:
    """
    Sentinel (used as a class, never instantiated) meaning "no default given".

    Passing it as ``default`` to read_value / read_dimvalue makes them raise an
    error instead of returning a fallback value.
    """
86
87
class NetcdfReader:
    """
    Wraps and extends netCDF4.Dataset. Read only mode. Supports with statements.

    Methods accepting a ``path`` argument address a (sub)group with a "/"-separated
    sequence of group names, e.g. "/group1/subgroup". The default "/" is the root group.

    Additional documentation available at:
        https://unidata.github.io/netcdf4-python/
    """

    Error = NetcdfReaderError

    @requires(netCDF4 is not None, "netCDF4 must be installed to use this class")
    def __init__(self, path):
        """
        Open the Netcdf file specified by path (read mode).

        Args:
            path: Path to the netCDF file (stored as an absolute path in ``self.path``).

        Raises:
            self.Error: if netCDF4 fails to open the file.
        """
        self.path = os.path.abspath(path)

        try:
            self.rootgrp = netCDF4.Dataset(self.path, mode="r")
        except Exception as exc:
            raise self.Error("In file %s: %s" % (self.path, str(exc))) from exc

        # walk_tree yields once per visited group (root included).
        self.ngroups = len(list(self.walk_tree()))

        # Always return non-masked numpy arrays.
        # Slicing a ncvar returns a MaskedArray and this is really annoying
        # because it can lead to unexpected behaviour in e.g. calls to np.matmul!
        # See also https://github.com/Unidata/netcdf4-python/issues/785
        self.rootgrp.set_auto_mask(False)

    def __enter__(self):
        """Activated when used in the with statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Activated at the end of the with statement. It automatically closes the file."""
        self.rootgrp.close()

    def close(self):
        """Close the file. Errors are logged, not raised (best effort)."""
        try:
            self.rootgrp.close()
        except Exception as exc:
            logger.warning("Exception %s while trying to close %s", exc, self.path)

    def walk_tree(self, top=None):
        """
        Navigate all the groups in the file starting from top.
        If top is None, the root group is used.

        Yields, for each visited group (depth-first), the view of its child groups.
        """
        if top is None:
            top = self.rootgrp

        values = top.groups.values()
        yield values
        for value in values:
            yield from self.walk_tree(value)

    def print_tree(self):
        """Print all the groups in the file."""
        for children in self.walk_tree():
            for child in children:
                print(child)

    def _get_group(self, path):
        """
        Return the netCDF group addressed by ``path``.

        ``path`` is a "/"-separated sequence of group names starting from the root
        group; "/" (or any path with no names) denotes the root group itself.

        Raises:
            self.Error: if one of the groups in ``path`` does not exist.
        """
        group = self.rootgrp
        for name in path.split("/"):
            if not name:
                # Leading/trailing/duplicated separators.
                continue
            try:
                group = group.groups[name]
            except KeyError:
                raise self.Error(
                    "In file %s:\nCannot find group `%s` while resolving path `%s`" % (self.path, name, path)
                ) from None
        return group

    def read_dimvalue(self, dimname, path="/", default=NO_DEFAULT):
        """
        Returns the value of a dimension.

        Args:
            dimname: Name of the dimension.
            path: path to the group.
            default: return `default` if `dimname` is not present and
                `default` is not `NO_DEFAULT` else raise self.Error.
        """
        try:
            dim = self._read_dimensions(dimname, path=path)[0]
            return len(dim)
        except self.Error:
            if default is NO_DEFAULT:
                raise
            return default

    def read_varnames(self, path="/"):
        """
        List of variable names stored in the group specified by path.

        Raises:
            self.Error: if the group does not exist.
        """
        return self._get_group(path).variables.keys()

    def read_value(self, varname, path="/", cmode=None, default=NO_DEFAULT):
        """
        Returns the values of variable with name varname in the group specified by path.

        Args:
            varname: Name of the variable
            path: path to the group.
            cmode: if cmode=="c", a complex ndarrays is constructed and returned
                (netcdf does not provide native support from complex datatype).
                The last dimension of the variable must then be 2 (real, imag).
            default: returns default if varname is not present.
                self.Error is raised if default is set to NO_DEFAULT

        Returns:
            numpy array if varname represents an array, scalar otherwise.

        Raises:
            ValueError: if cmode is not None/"c", or if cmode=="c" and the last
                dimension of the variable is not 2.
        """
        try:
            var = self.read_variable(varname, path=path)
        except self.Error:
            if default is NO_DEFAULT:
                raise
            return default

        if cmode is None:
            if var.shape:
                return var[:]
            # Scalar variable. getValue is not portable: depending on the netCDF4
            # version it returns a 1-element array or a 0-d array / numpy scalar.
            value = var.getValue()
            try:
                return value[0]
            except IndexError:
                return value

        if cmode != "c":
            raise ValueError("Wrong value for cmode %s" % cmode)
        if not var.shape or var.shape[-1] != 2:
            raise ValueError(
                "In file %s: variable `%s` cannot be read as complex, its shape %s does not end with 2"
                % (self.path, varname, var.shape)
            )
        # Read the data from disk only once.
        data = var[:]
        return data[..., 0] + 1j * data[..., 1]

    def read_variable(self, varname, path="/"):
        """Returns the variable with name varname in the group specified by path."""
        return self._read_variables(varname, path=path)[0]

    def _read_dimensions(self, *dimnames, **kwargs):
        """
        Return the netCDF dimensions ``dimnames`` in the group ``kwargs["path"]`` (default "/").

        Raises:
            self.Error: if the group or one of the dimensions does not exist.
        """
        path = kwargs.get("path", "/")
        group = self._get_group(path)
        try:
            return [group.dimensions[dname] for dname in dimnames]
        except KeyError as exc:
            raise self.Error(
                "In file %s:\nError while reading dimensions: `%s` with kwargs: `%s`" % (self.path, dimnames, kwargs)
            ) from exc

    def _read_variables(self, *varnames, **kwargs):
        """
        Return the netCDF variables ``varnames`` in the group ``kwargs["path"]`` (default "/").

        Raises:
            self.Error: if the group or one of the variables does not exist.
        """
        path = kwargs.get("path", "/")
        group = self._get_group(path)
        try:
            return [group.variables[vname] for vname in varnames]
        except KeyError as exc:
            raise self.Error(
                "In file %s:\nError while reading variables: `%s` with kwargs `%s`." % (self.path, varnames, kwargs)
            ) from exc

    def read_keys(self, keys, dict_cls=AttrDict, path="/"):
        """
        Read a list of variables/dimensions from file. If a key is not present the corresponding
        entry in the output dictionary is set to None.

        Each key is first looked up as a variable, then as a dimension.
        """
        od = dict_cls()
        for k in keys:
            try:
                # Try to read a variable.
                od[k] = self.read_value(k, path=path)
            except self.Error:
                try:
                    # Try to read a dimension.
                    od[k] = self.read_dimvalue(k, path=path)
                except self.Error:
                    od[k] = None

        return od
259
260
class ETSF_Reader(NetcdfReader):
    """
    Reader for netCDF files written according to the ETSF-IO specifications.

    The file is expected to contain (at least) the crystallographic section.
    """

    @lazy_property
    def chemical_symbols(self):
        """
        Chemical symbols, one per atom species.

        The netCDF variable is a char array [number of atom species][symbol length];
        each row is decoded from UTF-8, joined and stripped of surrounding whitespace.
        """
        return ["".join(char.decode("utf-8") for char in row).strip() for row in self.read_value("chemical_symbols")]

    def typeidx_from_symbol(self, symbol):
        """Return the type index (0-based, python convention) of the chemical symbol."""
        return self.chemical_symbols.index(symbol)

    def read_structure(self, cls=Structure):
        """Return the crystalline structure stored in the root group as a ``cls`` instance."""
        return structure_from_ncdata(self, cls=cls)

    def read_abinit_xcfunc(self):
        """Read ixc from an Abinit file and return the corresponding :class:`XcFunc` object."""
        ixc = self.read_value("ixc")
        return XcFunc.from_abinit_ixc(int(ixc))

    def read_abinit_hdr(self):
        """
        Read the variables associated to the Abinit header.

        Each entry of ``_HDR_VARIABLES`` is looked up in the root group under its
        ETSF-IO name (or its Abinit name when no ETSF-IO name is defined), first as a
        variable and then as a dimension.

        Returns:
            :class:`AbinitHeader` keyed by the Abinit variable names.

        Raises:
            ValueError: if one of the header entries is missing from the file.
        """
        header = {}
        variables = self.rootgrp.variables
        dimensions = self.rootgrp.dimensions

        for hvar in _HDR_VARIABLES.values():
            key = hvar.name
            ncname = key if hvar.etsf_name is None else hvar.etsf_name

            if ncname in variables:
                value = self.read_value(ncname)
            elif ncname in dimensions:
                value = self.read_dimvalue(ncname)
            else:
                raise ValueError("Cannot find `%s` in `%s`" % (ncname, self.path))

            # 0-d arrays become plain python scalars.
            if hasattr(value, "shape") and not value.shape:
                value = np.asarray(value).item()

            # Character arrays (numpy bytes) become python strings.
            if key == "codvsn":
                value = "".join(bs.decode("utf-8").strip() for bs in value)
            elif key in ("title", "md5_pseudos"):
                value = ["".join(bs.decode("utf-8") for bs in astr).strip() for astr in value]

            header[key] = value

        return AbinitHeader(header)
320
321
def structure_from_ncdata(ncdata, site_properties=None, cls=Structure):
    """
    Reads and returns a pymatgen structure from a NetCDF file
    containing crystallographic data in the ETSF-IO format.

    Args:
        ncdata: filename or NetcdfReader instance. If a filename is given, the file is
            opened here and closed before returning, also when an error occurs.
        site_properties: Iterable with the names of the netCDF variables to attach to
            the structure as site properties (each one is read with ``read_value``).
            Iterating over a dict yields its keys, so a dict is accepted as well.
        cls: The Structure class to instantiate.

    Returns:
        A ``cls`` instance; its class is replaced by abipy's Structure when abipy
        is importable.
    """
    ncdata, closeit = as_ncreader(ncdata)

    try:
        # TODO check whether atomic units are used
        # primitive_vectors are interpreted as Bohr and converted to Angstrom.
        lattice = ArrayWithUnit(ncdata.read_value("primitive_vectors"), "bohr").to("ang")

        red_coords = ncdata.read_value("reduced_atom_positions")
        natom = len(red_coords)

        znucl_type = ncdata.read_value("atomic_numbers")

        # type_atom[0:natom] --> index between 1 and number of atom species
        type_atom = ncdata.read_value("atom_species")

        # Fortran to C index and float --> int conversion.
        species = [int(znucl_type[type_atom[iat] - 1]) for iat in range(natom)]

        props = {}
        if site_properties is not None:
            props = {prop: ncdata.read_value(prop) for prop in site_properties}

        structure = cls(lattice, species, red_coords, site_properties=props)
    finally:
        # Do not leak the file handle we opened, even if a read above failed.
        if closeit:
            ncdata.close()

    # Quick and dirty hack.
    # I need an abipy structure since I need to_abivars and other methods.
    try:
        from abipy.core.structure import Structure as AbipyStructure

        structure.__class__ = AbipyStructure
    except ImportError:
        pass

    return structure
371
372
373class _H:
374    __slots__ = ["name", "doc", "etsf_name"]
375
376    def __init__(self, name, doc, etsf_name=None):
377        self.name, self.doc, self.etsf_name = name, doc, etsf_name
378
379
# Variables of the Abinit header, in the order in which ETSF_Reader.read_abinit_hdr reads them.
# `etsf_name` is given when the netCDF file stores the quantity under an ETSF-IO name.
# Every entry listed here must be present in the file (read_abinit_hdr raises otherwise).
_HDR_VARIABLES = (
    # Scalars
    _H("bantot", "total number of bands (sum of nband on all kpts and spins)"),
    _H("date", "starting date"),
    _H("headform", "format of the header"),
    _H("intxc", "input variable"),
    _H("ixc", "input variable"),
    _H("mband", "maxval(hdr%nband)", etsf_name="max_number_of_states"),
    _H("natom", "input variable", etsf_name="number_of_atoms"),
    _H("nkpt", "input variable", etsf_name="number_of_kpoints"),
    _H("npsp", "input variable"),
    _H("nspden", "input variable", etsf_name="number_of_components"),
    _H("nspinor", "input variable", etsf_name="number_of_spinor_components"),
    _H("nsppol", "input variable", etsf_name="number_of_spins"),
    _H("nsym", "input variable", etsf_name="number_of_symmetry_operations"),
    _H("ntypat", "input variable", etsf_name="number_of_atom_species"),
    _H("occopt", "input variable"),
    _H("pertcase", "the index of the perturbation, 0 if GS calculation"),
    _H("usepaw", "input variable (0=norm-conserving psps, 1=paw)"),
    _H("usewvl", "input variable (0=plane-waves, 1=wavelets)"),
    _H("kptopt", "input variable (defines symmetries used for k-point sampling)"),
    _H("pawcpxocc", "input variable"),
    _H(
        "nshiftk_orig",
        "original number of shifts given in input (changed in inkpts, the actual value is nshiftk)",
    ),
    _H("nshiftk", "number of shifts after inkpts."),
    _H("icoulomb", "input variable."),
    _H("ecut", "input variable", etsf_name="kinetic_energy_cutoff"),
    _H("ecutdg", "input variable (ecut for NC psps, pawecutdg for paw)"),
    _H("ecutsm", "input variable"),
    _H("ecut_eff", "ecut*dilatmx**2 (dilatmx is an input variable)"),
    _H("etot", "EVOLVING variable"),
    _H("fermie", "EVOLVING variable", etsf_name="fermi_energy"),
    _H("residm", "EVOLVING variable"),
    _H("stmbias", "input variable"),
    _H("tphysel", "input variable"),
    _H("tsmear", "input variable"),
    _H("nelect", "number of electrons (computed from pseudos and charge)"),
    _H("charge", "input variable"),
    # Arrays
    _H("qptn", "qptn(3) the wavevector, in case of a perturbation"),
    # NOTE(review): the entries below are disabled, presumably because ETSF-IO files
    # store them differently (e.g. ngfft split into three dimensions) -- TODO confirm.
    # _H("rprimd", "rprimd(3,3) EVOLVING variables", etsf_name="primitive_vectors"),
    # _H("ngfft", "ngfft(3) input variable", etsf_name="number_of_grid_points_vector1"),
    # _H("nwvlarr", "nwvlarr(2) the number of wavelets for each resolution.", etsf_name="number_of_wavelets"),
    _H("kptrlatt_orig", "kptrlatt_orig(3,3) Original kptrlatt"),
    _H("kptrlatt", "kptrlatt(3,3) kptrlatt after inkpts."),
    _H("istwfk", "input variable istwfk(nkpt)"),
    _H("lmn_size", "lmn_size(npsp) from psps"),
    _H("nband", "input variable nband(nkpt*nsppol)", etsf_name="number_of_states"),
    _H(
        "npwarr",
        "npwarr(nkpt) array holding npw for each k point",
        etsf_name="number_of_coefficients",
    ),
    _H("pspcod", "pspcod(npsp) from psps"),
    _H("pspdat", "pspdat(npsp) from psps"),
    _H("pspso", "pspso(npsp) from psps"),
    _H("pspxc", "pspxc(npsp) from psps"),
    _H("so_psp", "input variable so_psp(npsp)"),
    _H("symafm", "input variable symafm(nsym)"),
    # _H("symrel", "input variable symrel(3,3,nsym)", etsf_name="reduced_symmetry_matrices"),
    _H("typat", "input variable typat(natom)", etsf_name="atom_species"),
    _H(
        "kptns",
        "input variable kptns(nkpt, 3)",
        etsf_name="reduced_coordinates_of_kpoints",
    ),
    _H("occ", "EVOLVING variable occ(mband, nkpt, nsppol)", etsf_name="occupations"),
    _H(
        "tnons",
        "input variable tnons(nsym, 3)",
        etsf_name="reduced_symmetry_translations",
    ),
    _H("wtk", "weight of kpoints wtk(nkpt)", etsf_name="kpoint_weights"),
    _H("shiftk_orig", "original shifts given in input (changed in inkpts)."),
    _H("shiftk", "shiftk(3,nshiftk), shiftks after inkpts"),
    _H("amu", "amu(ntypat) EVOLVING variable"),
    # _H("xred", "EVOLVING variable xred(3,natom)", etsf_name="reduced_atom_positions"),
    _H("zionpsp", "zionpsp(npsp) from psps"),
    _H(
        "znuclpsp",
        "znuclpsp(npsp) from psps. Note the difference between (znucl|znucltypat) and znuclpsp",
    ),
    _H("znucltypat", "znucltypat(ntypat) from alchemy", etsf_name="atomic_numbers"),
    _H("codvsn", "version of the code"),
    _H("title", "title(npsp) from psps"),
    _H(
        "md5_pseudos",
        "md5pseudos(npsp), md5 checksums associated to pseudos (read from file)",
    ),
    # pawrhoij(:) (type(pawrhoij_type), EVOLVING variable, only for paw) is not included.
)
# Rebind to an insertion-ordered mapping: Abinit name -> _H descriptor.
_HDR_VARIABLES = OrderedDict([(h.name, h) for h in _HDR_VARIABLES])  # type: ignore
474
475
class AbinitHeader(AttrDict):
    """
    Mapping (monty ``AttrDict``) with the values reported in the Abinit header,
    keyed by the Abinit variable names.
    """

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0, title=None, **kwargs):
        """
        String representation. kwargs are passed to `pprint.pformat`.

        Args:
            verbose: Verbosity level (currently unused).
            title: Title string. If given, a marquee with the title is prepended.
        """
        from pprint import pformat

        text = pformat(self, **kwargs)
        if title is None:
            return text
        return "\n".join([marquee(title, mark="="), text])
501