# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
#
# pylint: disable=no-member
"""Wrapper for netCDF readers."""

import logging
import os.path
import warnings
from collections import OrderedDict

import numpy as np
from monty.collections import AttrDict
from monty.dev import requires
from monty.functools import lazy_property
from monty.string import marquee

from pymatgen.core.structure import Structure
from pymatgen.core.units import ArrayWithUnit
from pymatgen.core.xcfunc import XcFunc

logger = logging.getLogger(__name__)

__author__ = "Matteo Giantomassi"
__copyright__ = "Copyright 2013, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Matteo Giantomassi"
__email__ = "gmatteo at gmail.com"
__status__ = "Development"
__date__ = "$Feb 21, 2013M$"

__all__ = [
    "as_ncreader",
    "as_etsfreader",
    "NetcdfReader",
    "ETSF_Reader",
    "NO_DEFAULT",
    "structure_from_ncdata",
]

# netCDF4 is an optional dependency: the module stays importable without it,
# but NetcdfReader cannot be instantiated (see the `requires` decorator below).
try:
    import netCDF4
except ImportError as exc:
    netCDF4 = None
    warnings.warn(
        """\
`import netCDF4` failed with the following error:

%s

Please install netcdf4 with `conda install netcdf4`
If the conda version does not work, uninstall it with `conda uninstall hdf4 hdf5 netcdf4`
and use `pip install netcdf4`"""
        % str(exc)
    )


def _asreader(file, cls):
    """
    Return ``(reader, closeit)`` for `file`.

    If `file` is already an instance of `cls` it is returned unchanged with
    ``closeit=False``; otherwise ``cls(file)`` is built and ``closeit=True``
    tells the caller that it owns the reader and must close it.
    """
    closeit = False
    if not isinstance(file, cls):
        file, closeit = cls(file), True
    return file, closeit


def as_ncreader(file):
    """
    Convert file into a NetcdfReader instance.
    Returns reader, closeit where closeit is set to True
    if we have to close the file before leaving the procedure.
    """
    return _asreader(file, NetcdfReader)


def as_etsfreader(file):
    """
    Convert file into an ETSF_Reader. Accepts filename or ETSF_Reader.
    Returns reader, closeit with the same meaning as in `as_ncreader`.
    """
    return _asreader(file, ETSF_Reader)


class NetcdfReaderError(Exception):
    """Base error class for NetcdfReader"""


class NO_DEFAULT:
    """Signal that read_value should raise an Error"""


class NetcdfReader:
    """
    Wraps and extends netCDF4.Dataset. Read only mode. Supports with statements.

    Additional documentation available at:
        http://netcdf4-python.googlecode.com/svn/trunk/docs/netCDF4-module.html
    """

    Error = NetcdfReaderError

    @requires(netCDF4 is not None, "netCDF4 must be installed to use this class")
    def __init__(self, path):
        """
        Open the Netcdf file specified by path (read mode).

        Raises:
            self.Error: if the file cannot be opened.
        """
        self.path = os.path.abspath(path)

        try:
            self.rootgrp = netCDF4.Dataset(self.path, mode="r")
        except Exception as exc:
            raise self.Error("In file %s: %s" % (self.path, str(exc))) from exc

        # walk_tree yields once per visited node (root included).
        self.ngroups = len(list(self.walk_tree()))

        # Always return non-masked numpy arrays.
        # Slicing a ncvar returns a MaskedArray and this is really annoying
        # because it can lead to unexpected behaviour in e.g. calls to np.matmul!
        # See also https://github.com/Unidata/netcdf4-python/issues/785
        self.rootgrp.set_auto_mask(False)

    def __enter__(self):
        """Activated when used in the with statement."""
        return self

    def __exit__(self, type, value, traceback):
        """Activated at the end of the with statement. It automatically closes the file."""
        # Go through close() so that a failure while closing is logged instead of
        # replacing the exception (if any) raised inside the with block.
        self.close()

    def close(self):
        """Close the file. Errors raised while closing are logged, not propagated."""
        try:
            self.rootgrp.close()
        except Exception as exc:
            logger.warning("Exception %s while trying to close %s", exc, self.path)

    @lazy_property
    def path2group(self):
        """
        Mapping ``absolute group path -> group object``, e.g. ``/grp/subgrp``.
        The root group is stored under ``/``. Built on first access.
        """
        mapping = OrderedDict()
        mapping["/"] = self.rootgrp
        pending = [("", self.rootgrp)]
        while pending:
            prefix, group = pending.pop()
            for name, child in group.groups.items():
                child_path = prefix + "/" + name
                mapping[child_path] = child
                pending.append((child_path, child))
        return mapping

    def _get_group(self, path):
        """
        Return the group located at `path` (``/`` is the root group).
        Leading/trailing slashes are optional.

        Raises:
            self.Error: if there is no group at `path`.
        """
        if path == "/":
            # Fast path: no need to build the path2group mapping.
            return self.rootgrp

        key = "/" + path.strip("/")
        try:
            return self.path2group[key]
        except KeyError as exc:
            raise self.Error("In file %s: cannot find group `%s`" % (self.path, path)) from exc

    def walk_tree(self, top=None):
        """
        Navigate all the groups in the file starting from top.
        If top is None, the root group is used.

        Yields, for each visited group (top included), the collection of its
        direct subgroups.
        """
        if top is None:
            top = self.rootgrp

        yield top.groups.values()
        for subgroup in top.groups.values():
            yield from self.walk_tree(subgroup)

    def print_tree(self):
        """Print all the groups in the file."""
        for children in self.walk_tree():
            for child in children:
                print(child)

    def read_dimvalue(self, dimname, path="/", default=NO_DEFAULT):
        """
        Returns the value of a dimension.

        Args:
            dimname: Name of the dimension
            path: path to the group.
            default: return `default` if `dimname` is not present and
                `default` is not `NO_DEFAULT` else raise self.Error.
        """
        try:
            dim = self._read_dimensions(dimname, path=path)[0]
            return len(dim)
        except self.Error:
            if default is NO_DEFAULT:
                raise
            return default

    def read_varnames(self, path="/"):
        """
        List of variable names stored in the group specified by path.

        Raises:
            self.Error: if the group does not exist.
        """
        return self._get_group(path).variables.keys()

    def read_value(self, varname, path="/", cmode=None, default=NO_DEFAULT):
        """
        Returns the values of variable with name varname in the group specified by path.

        Args:
            varname: Name of the variable
            path: path to the group.
            cmode: if cmode=="c", a complex ndarrays is constructed and returned
                (netcdf does not provide native support from complex datatype).
            default: returns default if varname is not present.
                self.Error is raised if default is set to NO_DEFAULT

        Returns:
            numpy array if varname represents an array, scalar otherwise.
        """
        try:
            var = self.read_variable(varname, path=path)
        except self.Error:
            if default is NO_DEFAULT:
                raise
            return default

        if cmode is None:
            # scalar or array
            if var.shape:
                return var[:]

            # getValue is not portable: depending on the backend the scalar may be
            # wrapped in a one-element array or returned as-is (not indexable).
            value = var.getValue()
            try:
                return value[0]
            except IndexError:
                return value

        # Complex data: the last axis stores (real, imaginary).
        assert var.shape[-1] == 2
        if cmode == "c":
            return var[..., 0] + 1j * var[..., 1]
        raise ValueError("Wrong value for cmode %s" % cmode)

    def read_variable(self, varname, path="/"):
        """Returns the variable with name varname in the group specified by path."""
        return self._read_variables(varname, path=path)[0]

    def _read_dimensions(self, *dimnames, **kwargs):
        """
        Return the list of dimension objects named `dimnames`.
        The group is selected with the optional `path` keyword (default ``/``).

        Raises:
            self.Error: if the group or one of the dimensions is missing.
        """
        group = self._get_group(kwargs.get("path", "/"))
        try:
            return [group.dimensions[dname] for dname in dimnames]
        except KeyError as exc:
            raise self.Error(
                "In file %s:\nError while reading dimensions: `%s` with kwargs: `%s`" % (self.path, dimnames, kwargs)
            ) from exc

    def _read_variables(self, *varnames, **kwargs):
        """
        Return the list of variable objects named `varnames`.
        The group is selected with the optional `path` keyword (default ``/``).

        Raises:
            self.Error: if the group or one of the variables is missing.
        """
        group = self._get_group(kwargs.get("path", "/"))
        try:
            return [group.variables[vname] for vname in varnames]
        except KeyError as exc:
            raise self.Error(
                "In file %s:\nError while reading variables: `%s` with kwargs `%s`." % (self.path, varnames, kwargs)
            ) from exc

    def read_keys(self, keys, dict_cls=AttrDict, path="/"):
        """
        Read a list of variables/dimensions from file. If a key is not present the corresponding
        entry in the output dictionary is set to None.
        """
        od = dict_cls()
        for k in keys:
            try:
                # Try to read a variable.
                od[k] = self.read_value(k, path=path)
            except self.Error:
                try:
                    # Try to read a dimension.
                    od[k] = self.read_dimvalue(k, path=path)
                except self.Error:
                    od[k] = None

        return od


class ETSF_Reader(NetcdfReader):
    """
    This object reads data from a file written according to the ETSF-IO specifications.

    We assume that the netcdf file contains at least the crystallographic section.
    """

    @lazy_property
    def chemical_symbols(self):
        """Chemical symbols char [number of atom species][symbol length]."""
        charr = self.read_value("chemical_symbols")
        symbols = []
        for v in charr:
            # Each row is a sequence of single bytes: decode, join and strip the padding.
            s = "".join(c.decode("utf-8") for c in v)
            symbols.append(s.strip())

        return symbols

    def typeidx_from_symbol(self, symbol):
        """Returns the type index from the chemical symbol. Note python convention."""
        return self.chemical_symbols.index(symbol)

    def read_structure(self, cls=Structure):
        """Returns the crystalline structure stored in the rootgrp."""
        return structure_from_ncdata(self, cls=cls)

    def read_abinit_xcfunc(self):
        """
        Read ixc from an Abinit file. Return :class:`XcFunc` object.
        """
        ixc = int(self.read_value("ixc"))
        return XcFunc.from_abinit_ixc(ixc)

    def read_abinit_hdr(self):
        """
        Read the variables associated to the Abinit header.

        Each entry of `_HDR_VARIABLES` is looked up first among the variables and then
        among the dimensions of the root group, using its ETSF name when one is defined.

        Return :class:`AbinitHeader`

        Raises:
            ValueError: if an entry is neither a variable nor a dimension of the root group.
        """
        d = {}
        for hvar in _HDR_VARIABLES.values():
            ncname = hvar.etsf_name if hvar.etsf_name is not None else hvar.name
            if ncname in self.rootgrp.variables:
                d[hvar.name] = self.read_value(ncname)
            elif ncname in self.rootgrp.dimensions:
                d[hvar.name] = self.read_dimvalue(ncname)
            else:
                raise ValueError("Cannot find `%s` in `%s`" % (ncname, self.path))
            # Convert scalars to (well) scalars.
            if hasattr(d[hvar.name], "shape") and not d[hvar.name].shape:
                d[hvar.name] = np.asarray(d[hvar.name]).item()
            if hvar.name in ("title", "md5_pseudos", "codvsn"):
                # Convert array of numpy bytes to list of strings
                if hvar.name == "codvsn":
                    d[hvar.name] = "".join(bs.decode("utf-8").strip() for bs in d[hvar.name])
                else:
                    d[hvar.name] = ["".join(bs.decode("utf-8") for bs in astr).strip() for astr in d[hvar.name]]

        return AbinitHeader(d)


def structure_from_ncdata(ncdata, site_properties=None, cls=Structure):
    """
    Reads and returns a pymatgen structure from a NetCDF file
    containing crystallographic data in the ETSF-IO format.

    Args:
        ncdata: filename or NetcdfReader instance.
        site_properties: Dictionary with site properties.
        cls: The Structure class to instantiate.

    If `ncdata` is a filename, the reader opened here is closed before returning,
    also when reading fails. A reader passed by the caller is left open.
    """
    ncdata, closeit = as_ncreader(ncdata)

    try:
        # TODO check whether atomic units are used
        lattice = ArrayWithUnit(ncdata.read_value("primitive_vectors"), "bohr").to("ang")

        red_coords = ncdata.read_value("reduced_atom_positions")
        natom = len(red_coords)

        znucl_type = ncdata.read_value("atomic_numbers")

        # type_atom[0:natom] --> index Between 1 and number of atom species
        type_atom = ncdata.read_value("atom_species")

        # Fortran to C index and float --> int conversion.
        species = [int(znucl_type[type_atom[atom] - 1]) for atom in range(natom)]

        d = {}
        if site_properties is not None:
            for prop in site_properties:
                d[prop] = ncdata.read_value(prop)

        structure = cls(lattice, species, red_coords, site_properties=d)

        # Quick and dirty hack.
        # I need an abipy structure since I need to_abivars and other methods.
        try:
            from abipy.core.structure import Structure as AbipyStructure

            structure.__class__ = AbipyStructure
        except ImportError:
            pass

        return structure
    finally:
        # Close only the reader we opened ourselves.
        if closeit:
            ncdata.close()


class _H:
    """Descriptor of a header variable: Abinit name, description and optional ETSF-IO name."""

    __slots__ = ["name", "doc", "etsf_name"]

    def __init__(self, name, doc, etsf_name=None):
        self.name, self.doc, self.etsf_name = name, doc, etsf_name


_HDR_VARIABLES = (
    # Scalars
    _H("bantot", "total number of bands (sum of nband on all kpts and spins)"),
    _H("date", "starting date"),
    _H("headform", "format of the header"),
    _H("intxc", "input variable"),
    _H("ixc", "input variable"),
    _H("mband", "maxval(hdr%nband)", etsf_name="max_number_of_states"),
    _H("natom", "input variable", etsf_name="number_of_atoms"),
    _H("nkpt", "input variable", etsf_name="number_of_kpoints"),
    _H("npsp", "input variable"),
    _H("nspden", "input variable", etsf_name="number_of_components"),
    _H("nspinor", "input variable", etsf_name="number_of_spinor_components"),
    _H("nsppol", "input variable", etsf_name="number_of_spins"),
    _H("nsym", "input variable", etsf_name="number_of_symmetry_operations"),
    _H("ntypat", "input variable", etsf_name="number_of_atom_species"),
    _H("occopt", "input variable"),
    _H("pertcase", "the index of the perturbation, 0 if GS calculation"),
    _H("usepaw", "input variable (0=norm-conserving psps, 1=paw)"),
    _H("usewvl", "input variable (0=plane-waves, 1=wavelets)"),
    _H("kptopt", "input variable (defines symmetries used for k-point sampling)"),
    _H("pawcpxocc", "input variable"),
    _H(
        "nshiftk_orig",
        "original number of shifts given in input (changed in inkpts, the actual value is nshiftk)",
    ),
    _H("nshiftk", "number of shifts after inkpts."),
    _H("icoulomb", "input variable."),
    _H("ecut", "input variable", etsf_name="kinetic_energy_cutoff"),
    _H("ecutdg", "input variable (ecut for NC psps, pawecutdg for paw)"),
    _H("ecutsm", "input variable"),
    _H("ecut_eff", "ecut*dilatmx**2 (dilatmx is an input variable)"),
    _H("etot", "EVOLVING variable"),
    _H("fermie", "EVOLVING variable", etsf_name="fermi_energy"),
    _H("residm", "EVOLVING variable"),
    _H("stmbias", "input variable"),
    _H("tphysel", "input variable"),
    _H("tsmear", "input variable"),
    _H("nelect", "number of electrons (computed from pseudos and charge)"),
    _H("charge", "input variable"),
    # Arrays
    _H("qptn", "qptn(3) the wavevector, in case of a perturbation"),
    # _H("rprimd", "rprimd(3,3) EVOLVING variables", etsf_name="primitive_vectors"),
    # _H(ngfft, "ngfft(3) input variable", number_of_grid_points_vector1"
    # _H("nwvlarr", "nwvlarr(2) the number of wavelets for each resolution.", etsf_name="number_of_wavelets"),
    _H("kptrlatt_orig", "kptrlatt_orig(3,3) Original kptrlatt"),
    _H("kptrlatt", "kptrlatt(3,3) kptrlatt after inkpts."),
    _H("istwfk", "input variable istwfk(nkpt)"),
    _H("lmn_size", "lmn_size(npsp) from psps"),
    _H("nband", "input variable nband(nkpt*nsppol)", etsf_name="number_of_states"),
    _H(
        "npwarr",
        "npwarr(nkpt) array holding npw for each k point",
        etsf_name="number_of_coefficients",
    ),
    _H("pspcod", "pscod(npsp) from psps"),
    _H("pspdat", "psdat(npsp) from psps"),
    _H("pspso", "pspso(npsp) from psps"),
    _H("pspxc", "pspxc(npsp) from psps"),
    _H("so_psp", "input variable so_psp(npsp)"),
    _H("symafm", "input variable symafm(nsym)"),
    # _H(symrel="input variable symrel(3,3,nsym)", etsf_name="reduced_symmetry_matrices"),
    _H("typat", "input variable typat(natom)", etsf_name="atom_species"),
    _H(
        "kptns",
        "input variable kptns(nkpt, 3)",
        etsf_name="reduced_coordinates_of_kpoints",
    ),
    _H("occ", "EVOLVING variable occ(mband, nkpt, nsppol)", etsf_name="occupations"),
    _H(
        "tnons",
        "input variable tnons(nsym, 3)",
        etsf_name="reduced_symmetry_translations",
    ),
    _H("wtk", "weight of kpoints wtk(nkpt)", etsf_name="kpoint_weights"),
    _H("shiftk_orig", "original shifts given in input (changed in inkpts)."),
    _H("shiftk", "shiftk(3,nshiftk), shiftks after inkpts"),
    _H("amu", "amu(ntypat) ! EVOLVING variable"),
    # _H("xred", "EVOLVING variable xred(3,natom)", etsf_name="reduced_atom_positions"),
    _H("zionpsp", "zionpsp(npsp) from psps"),
    _H(
        "znuclpsp",
        "znuclpsp(npsp) from psps. Note the difference between (znucl|znucltypat) and znuclpsp",
    ),
    _H("znucltypat", "znucltypat(ntypat) from alchemy", etsf_name="atomic_numbers"),
    _H("codvsn", "version of the code"),
    _H("title", "title(npsp) from psps"),
    _H(
        "md5_pseudos",
        "md5pseudos(npsp), md5 checksums associated to pseudos (read from file)",
    ),
    # _H(type(pawrhoij_type), allocatable :: pawrhoij(:) ! EVOLVING variable, only for paw
)
# Re-bind the tuple to an ordered mapping ``name -> _H`` (declaration order is preserved).
_HDR_VARIABLES = OrderedDict([(h.name, h) for h in _HDR_VARIABLES])  # type: ignore


class AbinitHeader(AttrDict):
    """Stores the values reported in the Abinit header."""

    # def __init__(self, *args, **kwargs):
    #    super().__init__(*args, **kwargs)
    #    for k, v in self.items():
    #        v.__doc__ = _HDR_VARIABLES[k].doc

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0, title=None, **kwargs):
        """
        String representation. kwargs are passed to `pprint.pformat`.

        Args:
            verbose: Verbosity level
            title: Title string.
        """
        from pprint import pformat

        s = pformat(self, **kwargs)
        if title is not None:
            return "\n".join([marquee(title, mark="="), s])
        return s