"""
This module defines a simplified interface for generating ABINIT input files.
Note that not all the features of Abinit are supported by BasicAbinitInput.
For a more comprehensive implementation, use the AbinitInput object provided by AbiPy.
"""

import abc
import copy
import json
import logging
import os
from collections import OrderedDict, namedtuple
from collections.abc import Mapping, MutableMapping
from enum import Enum

import numpy as np
from monty.collections import AttrDict
from monty.json import MSONable
from monty.string import is_string, list_strings

from pymatgen.core.structure import Structure
from pymatgen.io.abinit import abiobjects as aobj
from pymatgen.io.abinit.pseudos import Pseudo, PseudoTable
from pymatgen.io.abinit.variable import InputVariable
from pymatgen.util.serialization import pmg_serialize

logger = logging.getLogger(__file__)


# List of Abinit variables used to specify the structure.
# These variables should not be passed to set_vars since
# they will be generated with structure.to_abivars()
GEOVARS = {
    "acell",
    "rprim",
    # Every entry needs its trailing comma: adjacent string literals are
    # silently concatenated by Python ("rprimd" "angdeg" -> "rprimdangdeg").
    "rprimd",
    "angdeg",
    "xred",
    "xcart",
    "xangst",
    "znucl",
    "typat",
    "ntypat",
    "natom",
}

# Variables defining tolerances (used in pop_tolerances)
_TOLVARS = {
    "toldfe",
    "tolvrs",
    "tolwfr",
    "tolrff",
    "toldff",
    "tolimg",
    "tolmxf",
    "tolrde",
}

# Variables defining tolerances for the SCF cycle that are mutually exclusive
_TOLVARS_SCF = {
    "toldfe",
    "tolvrs",
    "tolwfr",
    "tolrff",
    "toldff",
}

# Variables determining if data files should be read in input
_IRDVARS = {
    "irdbseig",
    "irdbsreso",
    "irdhaydock",
    "irdddk",
    "irdden",
    "ird1den",
    "irdqps",
    "irdkss",
    "irdscr",
    "irdsuscep",
    "irdvdw",
    "irdwfk",
    "irdwfkfine",
    "irdwfq",
    "ird1wf",
}

# Name of the (default) tolerance used by
# the runlevels.
_runl2tolname = {
    "scf": "tolvrs",
    "nscf": "tolwfr",
    "dfpt": "toldfe",  # ?
    "screening": "toldfe",  # dummy
    "sigma": "toldfe",  # dummy
    "bse": "toldfe",  # ?
    "relax": "tolrff",
}

# Tolerances for the different levels of accuracy.

T = namedtuple("T", "low normal high")
_tolerances = {
    "toldfe": T(1.0e-7, 1.0e-8, 1.0e-9),
    "tolvrs": T(1.0e-7, 1.0e-8, 1.0e-9),
    "tolwfr": T(1.0e-15, 1.0e-17, 1.0e-19),
    "tolrff": T(0.04, 0.02, 0.01),
}
del T


# Default values used if user does not specify them
_DEFAULTS = dict(
    kppa=1000,
)


def as_structure(obj):
    """
    Convert obj into a Structure. Accepts:

        - Structure object.
        - Filename
        - Dictionaries (MSONable format or dictionaries with abinit variables).

    Raises:
        TypeError: if obj is none of the above.
    """
    if isinstance(obj, Structure):
        return obj

    if is_string(obj):
        return Structure.from_file(obj)

    if isinstance(obj, Mapping):
        # MSONable dicts carry "@module"; anything else is treated as ABINIT geometry variables.
        if "@module" in obj:
            return Structure.from_dict(obj)
        return aobj.structure_from_abivars(cls=None, **obj)

    raise TypeError("Don't know how to convert %s into a structure" % type(obj))


class ShiftMode(Enum):
    """
    Class defining the mode to be used for the shifts.
    G: Gamma centered
    M: Monkhorst-Pack ((0.5, 0.5, 0.5))
    S: Symmetric. Respects the chksymbreak with multiple shifts
    O: OneSymmetric. Respects the chksymbreak with a single shift (as in 'S' if a single shift is given, gamma
        centered otherwise).
    """

    GammaCentered = "G"
    MonkhorstPack = "M"
    Symmetric = "S"
    OneSymmetric = "O"

    @classmethod
    def from_object(cls, obj):
        """
        Returns an instance of ShiftMode based on the type of object passed. Converts strings to ShiftMode depending
        on the initial letter of the string. G for GammaCentered, M for MonkhorstPack,
        S for Symmetric, O for OneSymmetric.
        Case insensitive.
        """
        if isinstance(obj, cls):
            return obj
        if is_string(obj):
            return cls(obj[0].upper())
        raise TypeError("The object provided is not handled: type %s" % type(obj))


def _stopping_criterion(runlevel, accuracy):
    """Return the stopping criterion for this runlevel with the given accuracy."""
    tolname = _runl2tolname[runlevel]
    return {tolname: getattr(_tolerances[tolname], accuracy)}


def _find_ecut_pawecutdg(ecut, pawecutdg, pseudos, accuracy):
    """
    Return a |AttrDict| with the value of ``ecut`` and ``pawecutdg``.

    Values that are None are taken from the hints of the pseudos for the given accuracy
    (the maximum over all pseudos). pawecutdg is only looked up if at least one pseudo is PAW.

    Raises:
        RuntimeError: if a value must be looked up but the pseudos do not provide hints.
    """
    # Get ecut and pawecutdg from the pseudo hints.
    if ecut is None or (pawecutdg is None and any(p.ispaw for p in pseudos)):
        has_hints = all(p.has_hints for p in pseudos)

    if ecut is None:
        if has_hints:
            ecut = max(p.hint_for_accuracy(accuracy).ecut for p in pseudos)
        else:
            raise RuntimeError("ecut is None but pseudos do not provide hints for ecut")

    if pawecutdg is None and any(p.ispaw for p in pseudos):
        if has_hints:
            pawecutdg = max(p.hint_for_accuracy(accuracy).pawecutdg for p in pseudos)
        else:
            raise RuntimeError("pawecutdg is None but pseudos do not provide hints")

    return AttrDict(ecut=ecut, pawecutdg=pawecutdg)


def _find_scf_nband(structure, pseudos, electrons, spinat=None):
    """
    Find the value of ``nband``.

    Returns electrons.nband if it is already set, otherwise a heuristic estimate
    based on the number of valence electrons, the smearing and (for nsppol == 2)
    the starting magnetization ``spinat``. The result is always an even int.
    """
    if electrons.nband is not None:
        return electrons.nband

    nsppol, smearing = electrons.nsppol, electrons.smearing

    # Number of valence electrons including possible extra charge
    nval = num_valence_electrons(structure, pseudos)
    nval -= electrons.charge

    # First guess (semiconductors)
    nband = nval // 2

    # TODO: Find better algorithm
    # If nband is too small we may kill the job, increase nband and restart
    # but this change could cause problems in the other steps of the calculation
    # if the change is not propagated e.g. phonons in metals.
    if smearing:
        # metallic occupation
        nband = max(np.ceil(nband * 1.2), nband + 10)
    else:
        nband = max(np.ceil(nband * 1.1), nband + 4)

    # Increase number of bands based on the starting magnetization
    if nsppol == 2 and spinat is not None:
        nband += np.ceil(max(np.sum(spinat, axis=0)) / 2.0)

    # Force even nband (easier to divide among procs, mandatory if nspinor == 2)
    nband += nband % 2
    return int(nband)


def _get_shifts(shift_mode, structure):
    """
    Gives the shifts based on the selected shift mode and on the symmetry of the structure.
    G: Gamma centered
    M: Monkhorst-Pack ((0.5, 0.5, 0.5))
    S: Symmetric. Respects the chksymbreak with multiple shifts
    O: OneSymmetric. Respects the chksymbreak with a single shift (as in 'S' if a single shift is given, gamma
        centered otherwise).

    Note: for some cases (e.g. body centered tetragonal), both the Symmetric and OneSymmetric may fail to satisfy the
        ``chksymbreak`` condition (Abinit input variable).

    Raises:
        ValueError: if shift_mode is not a known ShiftMode.
    """
    if shift_mode == ShiftMode.GammaCentered:
        return ((0, 0, 0),)
    if shift_mode == ShiftMode.MonkhorstPack:
        return ((0.5, 0.5, 0.5),)
    if shift_mode == ShiftMode.Symmetric:
        return calc_shiftk(structure)
    if shift_mode == ShiftMode.OneSymmetric:
        shifts = calc_shiftk(structure)
        if len(shifts) == 1:
            return shifts
        return ((0, 0, 0),)

    raise ValueError("invalid shift_mode: `%s`" % str(shift_mode))


def gs_input(
    structure,
    pseudos,
    kppa=None,
    ecut=None,
    pawecutdg=None,
    scf_nband=None,
    accuracy="normal",
    spin_mode="polarized",
    smearing="fermi_dirac:0.1 eV",
    charge=0.0,
    scf_algorithm=None,
):
    """
    Returns a |BasicAbinitInput| for ground-state calculation.

    Args:
        structure: |Structure| object.
        pseudos: List of filenames or list of |Pseudo| objects or |PseudoTable| object.
        kppa: Defines the sampling used for the SCF run. Defaults to 1000 if not given.
        ecut: cutoff energy in Ha (if None, ecut is initialized from the pseudos according to accuracy)
        pawecutdg: cutoff energy in Ha for PAW double-grid (if None, pawecutdg is initialized from the pseudos
            according to accuracy)
        scf_nband: Number of bands for SCF run. If scf_nband is None, nband is automatically initialized
            from the list of pseudos, the structure and the smearing option.
        accuracy: Accuracy of the calculation.
        spin_mode: Spin polarization.
        smearing: Smearing technique.
        charge: Electronic charge added to the unit cell.
        scf_algorithm: Algorithm used for solving of the SCF cycle.
    """
    # ndivsm=0 makes ebands_input stop right after setting up the SCF dataset.
    multi = ebands_input(
        structure,
        pseudos,
        kppa=kppa,
        ndivsm=0,
        ecut=ecut,
        pawecutdg=pawecutdg,
        scf_nband=scf_nband,
        accuracy=accuracy,
        spin_mode=spin_mode,
        smearing=smearing,
        charge=charge,
        scf_algorithm=scf_algorithm,
    )

    return multi[0]


def ebands_input(
    structure,
    pseudos,
    kppa=None,
    nscf_nband=None,
    ndivsm=15,
    ecut=None,
    pawecutdg=None,
    scf_nband=None,
    accuracy="normal",
    spin_mode="polarized",
    smearing="fermi_dirac:0.1 eV",
    charge=0.0,
    scf_algorithm=None,
    dos_kppa=None,
):
    """
    Returns a |BasicMultiDataset| object for band structure calculations.

    Args:
        structure: |Structure| object.
        pseudos: List of filenames or list of |Pseudo| objects or |PseudoTable| object.
        kppa: Defines the sampling used for the SCF run. Defaults to 1000 if not given.
        nscf_nband: Number of bands included in the NSCF run. Set to scf_nband + 10 if None.
        ndivsm: Number of divisions used to sample the smallest segment of the k-path.
            if 0, only the GS input is returned in multi[0].
        ecut: cutoff energy in Ha (if None, ecut is initialized from the pseudos according to accuracy)
        pawecutdg: cutoff energy in Ha for PAW double-grid (if None, pawecutdg is initialized from the pseudos
            according to accuracy)
        scf_nband: Number of bands for SCF run. If scf_nband is None, nband is automatically initialized
            from the list of pseudos, the structure and the smearing option.
        accuracy: Accuracy of the calculation.
        spin_mode: Spin polarization.
        smearing: Smearing technique.
        charge: Electronic charge added to the unit cell.
        scf_algorithm: Algorithm used for solving of the SCF cycle.
        dos_kppa: Scalar or List of integers with the number of k-points per atom
            to be used for the computation of the DOS (None if DOS is not wanted).
    """
    structure = as_structure(structure)

    if dos_kppa is not None and not isinstance(dos_kppa, (list, tuple)):
        dos_kppa = [dos_kppa]

    # Dataset 0: SCF, dataset 1: band structure, datasets 2..: one DOS run per dos_kppa entry.
    multi = BasicMultiDataset(structure, pseudos, ndtset=2 if dos_kppa is None else 2 + len(dos_kppa))

    # Set the cutoff energies.
    multi.set_vars(_find_ecut_pawecutdg(ecut, pawecutdg, multi.pseudos, accuracy))

    # SCF calculation.
    kppa = _DEFAULTS.get("kppa") if kppa is None else kppa
    scf_ksampling = aobj.KSampling.automatic_density(structure, kppa, chksymbreak=0)
    scf_electrons = aobj.Electrons(
        spin_mode=spin_mode,
        smearing=smearing,
        algorithm=scf_algorithm,
        charge=charge,
        nband=scf_nband,
        fband=None,
    )

    if scf_electrons.nband is None:
        scf_electrons.nband = _find_scf_nband(structure, multi.pseudos, scf_electrons, multi[0].get("spinat", None))

    multi[0].set_vars(scf_ksampling.to_abivars())
    multi[0].set_vars(scf_electrons.to_abivars())
    multi[0].set_vars(_stopping_criterion("scf", accuracy))
    if ndivsm == 0:
        return multi

    # Band structure calculation.
    nscf_ksampling = aobj.KSampling.path_from_structure(ndivsm, structure)
    nscf_nband = scf_electrons.nband + 10 if nscf_nband is None else nscf_nband
    nscf_electrons = aobj.Electrons(
        spin_mode=spin_mode,
        smearing=smearing,
        algorithm={"iscf": -2},
        charge=charge,
        nband=nscf_nband,
        fband=None,
    )

    multi[1].set_vars(nscf_ksampling.to_abivars())
    multi[1].set_vars(nscf_electrons.to_abivars())
    multi[1].set_vars(_stopping_criterion("nscf", accuracy))

    # DOS calculation with different values of kppa.
    if dos_kppa is not None:
        for i, kppa_ in enumerate(dos_kppa):
            dos_ksampling = aobj.KSampling.automatic_density(structure, kppa_, chksymbreak=0)
            dos_electrons = aobj.Electrons(
                spin_mode=spin_mode,
                smearing=smearing,
                algorithm={"iscf": -2},
                charge=charge,
                nband=nscf_nband,
            )
            dt = 2 + i
            multi[dt].set_vars(dos_ksampling.to_abivars())
            multi[dt].set_vars(dos_electrons.to_abivars())
            multi[dt].set_vars(_stopping_criterion("nscf", accuracy))

    return multi


def ion_ioncell_relax_input(
    structure,
    pseudos,
    kppa=None,
    nband=None,
    ecut=None,
    pawecutdg=None,
    accuracy="normal",
    spin_mode="polarized",
    smearing="fermi_dirac:0.1 eV",
    charge=0.0,
    scf_algorithm=None,
    shift_mode="Monkhorst-pack",
):
    """
    Returns a |BasicMultiDataset| for a structural relaxation. The first dataset optimizes the
    atomic positions at fixed unit cell. The second dataset optimizes both ions and unit cell parameters.

    Args:
        structure: |Structure| object.
        pseudos: List of filenames or list of |Pseudo| objects or |PseudoTable| object.
        kppa: Defines the sampling used for the Brillouin zone. Defaults to 1000 if not given.
        nband: Number of bands included in the SCF run.
        ecut: cutoff energy in Ha (if None, ecut is initialized from the pseudos according to accuracy)
        pawecutdg: cutoff energy in Ha for PAW double-grid (if None, pawecutdg is initialized from the pseudos
            according to accuracy)
        accuracy: Accuracy of the calculation.
        spin_mode: Spin polarization.
        smearing: Smearing technique.
        charge: Electronic charge added to the unit cell.
        scf_algorithm: Algorithm used for the solution of the SCF cycle.
        shift_mode: ShiftMode or string (only the first letter is used, see ShiftMode.from_object)
            selecting the shifts of the k-mesh.
    """
    structure = as_structure(structure)
    multi = BasicMultiDataset(structure, pseudos, ndtset=2)

    # Set the cutoff energies.
    multi.set_vars(_find_ecut_pawecutdg(ecut, pawecutdg, multi.pseudos, accuracy))

    kppa = _DEFAULTS.get("kppa") if kppa is None else kppa

    shift_mode = ShiftMode.from_object(shift_mode)
    shifts = _get_shifts(shift_mode, structure)
    ksampling = aobj.KSampling.automatic_density(structure, kppa, chksymbreak=0, shifts=shifts)
    electrons = aobj.Electrons(
        spin_mode=spin_mode,
        smearing=smearing,
        algorithm=scf_algorithm,
        charge=charge,
        nband=nband,
        fband=None,
    )

    if electrons.nband is None:
        electrons.nband = _find_scf_nband(structure, multi.pseudos, electrons, multi[0].get("spinat", None))

    ion_relax = aobj.RelaxationMethod.atoms_only(atoms_constraints=None)
    ioncell_relax = aobj.RelaxationMethod.atoms_and_cell(atoms_constraints=None)

    multi.set_vars(electrons.to_abivars())
    multi.set_vars(ksampling.to_abivars())

    multi[0].set_vars(ion_relax.to_abivars())
    multi[0].set_vars(_stopping_criterion("relax", accuracy))

    multi[1].set_vars(ioncell_relax.to_abivars())
    multi[1].set_vars(_stopping_criterion("relax", accuracy))

    return multi


def calc_shiftk(structure, symprec=0.01, angle_tolerance=5):
    """
    Find the values of ``shiftk`` and ``nshiftk`` appropriated for the sampling of the Brillouin zone.

    When the primitive vectors of the lattice do NOT form a FCC or a BCC lattice,
    the usual (shifted) Monkhorst-Pack grids are formed by using nshiftk=1 and shiftk 0.5 0.5 0.5 .
    This is often the preferred k point sampling. For a non-shifted Monkhorst-Pack grid,
    use `nshiftk=1` and `shiftk 0.0 0.0 0.0`, but there is little reason to do that.

    When the primitive vectors of the lattice form a FCC lattice, with rprim::

            0.0 0.5 0.5
            0.5 0.0 0.5
            0.5 0.5 0.0

    the (very efficient) usual Monkhorst-Pack sampling will be generated by using nshiftk= 4 and shiftk::

        0.5 0.5 0.5
        0.5 0.0 0.0
        0.0 0.5 0.0
        0.0 0.0 0.5

    When the primitive vectors of the lattice form a BCC lattice, with rprim::

           -0.5  0.5  0.5
            0.5 -0.5  0.5
            0.5  0.5 -0.5

    the usual Monkhorst-Pack sampling will be generated by using nshiftk= 2 and shiftk::

            0.25  0.25  0.25
           -0.25 -0.25 -0.25

    However, the simple sampling nshiftk=1 and shiftk 0.5 0.5 0.5 is excellent.

    For hexagonal lattices with hexagonal axes, e.g. rprim::

            1.0  0.0       0.0
           -0.5  sqrt(3)/2 0.0
            0.0  0.0       1.0

    one can use nshiftk= 1 and shiftk 0.0 0.0 0.5
    In rhombohedral axes, e.g. using angdeg 3*60., this corresponds to shiftk 0.5 0.5 0.5,
    to keep the shift along the symmetry axis.

    Returns:
        Suggested value of shiftk as an array of shape (nshiftk, 3).

    Raises:
        ValueError: for a primitive hexagonal cell in which no lattice angle is within 1 degree of 120.
    """
    # Find lattice type.
    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    sym = SpacegroupAnalyzer(structure, symprec=symprec, angle_tolerance=angle_tolerance)
    lattice_type, spg_symbol = sym.get_lattice_type(), sym.get_space_group_symbol()

    # Check if the cell is primitive
    is_primitive = len(sym.find_primitive()) == len(structure)

    # Generate the appropriate set of shifts.
    shiftk = None

    if is_primitive:
        if lattice_type == "cubic":
            if "F" in spg_symbol:
                # FCC
                shiftk = [0.5, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.5]

            elif "I" in spg_symbol:
                # BCC
                shiftk = [0.25, 0.25, 0.25, -0.25, -0.25, -0.25]
                # Alternative (also excellent, see docstring): shiftk = [0.5, 0.5, 0.5]

        elif lattice_type == "hexagonal":
            # Find the hexagonal axis and set the shift along it:
            # lattice vector i is taken as the hexagonal axis when angles[i] is ~120 degrees.
            for i, angle in enumerate(structure.lattice.angles):
                if abs(angle - 120) < 1.0:
                    hex_ax = i
                    break
            else:
                raise ValueError("Cannot find hexagonal axis")

            shiftk = [0.0, 0.0, 0.0]
            shiftk[hex_ax] = 0.5

        elif lattice_type == "tetragonal":
            if "I" in spg_symbol:
                # BCT
                shiftk = [0.25, 0.25, 0.25, -0.25, -0.25, -0.25]

    if shiftk is None:
        # Use default value.
        shiftk = [0.5, 0.5, 0.5]

    return np.reshape(shiftk, (-1, 3))


def num_valence_electrons(structure, pseudos):
    """
    Returns the number of valence electrons.

    Args:
        structure: |Structure| object.
        pseudos: List of |Pseudo| objects or list of filenames.

    Returns:
        The sum of Z_val over all sites; an int when the sum is integral.
    """
    nval, table = 0, PseudoTable.as_table(pseudos)
    for site in structure:
        pseudo = table.pseudo_with_symbol(site.specie.symbol)
        nval += pseudo.Z_val

    return int(nval) if int(nval) == nval else nval


class AbstractInput(MutableMapping, metaclass=abc.ABCMeta):
    """
    Abstract class defining the methods that must be implemented by Input objects.
    """

    # ABC protocol: __delitem__, __getitem__, __iter__, __len__, __setitem__
    def __delitem__(self, key):
        return self.vars.__delitem__(key)

    def __getitem__(self, key):
        return self.vars.__getitem__(key)

    def __iter__(self):
        return self.vars.__iter__()

    def __len__(self):
        return len(self.vars)

    def __setitem__(self, key, value):
        self._check_varname(key)
        return self.vars.__setitem__(key, value)

    def __repr__(self):
        return "<%s at %s>" % (self.__class__.__name__, id(self))

    def __str__(self):
        return self.to_string()

    def write(self, filepath="run.abi"):
        """
        Write the input file to ``filepath``, creating the parent directory if needed.
        """
        dirname = os.path.dirname(os.path.abspath(filepath))
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(dirname, exist_ok=True)

        # Write the input file.
        with open(filepath, "wt") as fh:
            fh.write(str(self))

    def deepcopy(self):
        """Deep copy of the input."""
        return copy.deepcopy(self)

    def set_vars(self, *args, **kwargs):
        """
        Set the value of the variables.
        Return dict with the variables added to the input.

        Example:

            input.set_vars(ecut=10, ionmov=3)
        """
        kwargs.update(dict(*args))
        for varname, varvalue in kwargs.items():
            self[varname] = varvalue
        return kwargs

    def set_vars_ifnotin(self, *args, **kwargs):
        """
        Set the value of the variables but only if the variable is not already present.
        Return dict with the variables added to the input.

        Example:

            input.set_vars_ifnotin(ecut=10, ionmov=3)
        """
        kwargs.update(dict(*args))
        added = {}
        for varname, varvalue in kwargs.items():
            if varname not in self:
                self[varname] = varvalue
                added[varname] = varvalue
        return added

    def pop_vars(self, keys):
        """
        Remove the variables listed in keys.
        Return dictionary with the variables that have been removed.
        Unlike remove_vars, no exception is raised if the variables are not in the input.

        Args:
            keys: string or list of strings with variable names.

        Example:
            inp.pop_vars(["ionmov", "optcell", "ntime", "dilatmx"])
        """
        return self.remove_vars(keys, strict=False)

    def remove_vars(self, keys, strict=True):
        """
        Remove the variables listed in keys.
        Return dictionary with the variables that have been removed.

        Args:
            keys: string or list of strings with variable names.
            strict: If True, KeyError is raised if at least one variable is not present.
        """
        removed = {}
        for key in list_strings(keys):
            if strict and key not in self:
                raise KeyError("key: %s not in self:\n %s" % (key, list(self.keys())))
            if key in self:
                removed[key] = self.pop(key)

        return removed

    @property
    @abc.abstractmethod
    def vars(self):
        """Dictionary with the input variables. Used to implement dict-like interface."""

    @abc.abstractmethod
    def _check_varname(self, key):
        """Check if key is a valid name. Raise self.Error if not valid."""

    @abc.abstractmethod
    def to_string(self):
        """Returns a string with the input."""


class BasicAbinitInputError(Exception):
    """Base error class for exceptions raised by ``BasicAbinitInput``."""


class BasicAbinitInput(AbstractInput, MSONable):
    """
    This object stores the ABINIT variables for a single dataset.
    """

    Error = BasicAbinitInputError

    def __init__(
        self,
        structure,
        pseudos,
        pseudo_dir=None,
        comment=None,
        abi_args=None,
        abi_kwargs=None,
    ):
        """
        Args:
            structure: Parameters defining the crystalline structure. Accepts |Structure| object
                file with structure (CIF, netcdf file, ...) or dictionary with ABINIT geo variables.
            pseudos: Pseudopotentials to be used for the calculation. Accepts: string or list of strings
                with the name of the pseudopotential files, list of |Pseudo| objects
                or |PseudoTable| object.
            pseudo_dir: Name of the directory where the pseudopotential files are located.
            comment: Optional string with a comment that will be placed at the beginning of the file.
            abi_args: iterable of tuples (key, value) with the initial set of variables. Default: Empty
            abi_kwargs: Dictionary with the initial set of variables. Default: Empty

        Raises:
            BasicAbinitInputError: if a geometry variable is passed, pseudo_dir does not exist
                or the pseudos cannot be matched to the structure.
        """
        # Internal dict with variables. we use an ordered dict so that
        # variables will be likely grouped by `topics` when we fill the input.
        # abi_args is materialized once because it is traversed twice below:
        # a one-shot iterable (e.g. a generator) would otherwise be exhausted by the
        # validation loop and its variables silently dropped.
        abi_args = [] if abi_args is None else list(abi_args)
        for key, _ in abi_args:
            self._check_varname(key)

        abi_kwargs = {} if abi_kwargs is None else abi_kwargs
        for key in abi_kwargs:
            self._check_varname(key)

        args = list(abi_args)
        args.extend(abi_kwargs.items())

        self._vars = OrderedDict(args)
        self.set_structure(structure)

        if pseudo_dir is not None:
            pseudo_dir = os.path.abspath(pseudo_dir)
            if not os.path.exists(pseudo_dir):
                raise self.Error("Directory %s does not exist" % pseudo_dir)
            pseudos = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)]

        try:
            self._pseudos = PseudoTable.as_table(pseudos).get_pseudos_for_structure(self.structure)
        except ValueError as exc:
            raise self.Error(str(exc))

        if comment is not None:
            self.set_comment(comment)

    @pmg_serialize
    def as_dict(self):
        """
        JSON interface used in pymatgen for easier serialization.
        """
        # Use a list of (key, value) to serialize the OrderedDict
        abi_args = []
        for key, value in self.items():
            if isinstance(value, np.ndarray):
                value = value.tolist()
            abi_args.append((key, value))

        return dict(
            structure=self.structure.as_dict(),
            pseudos=[p.as_dict() for p in self.pseudos],
            comment=self.comment,
            abi_args=abi_args,
        )

    @property
    def vars(self):
        """Dictionary with variables."""
        return self._vars

    @classmethod
    def from_dict(cls, d):
        """
        JSON interface used in pymatgen for easier serialization.
        """
        pseudos = [Pseudo.from_file(p["filepath"]) for p in d["pseudos"]]
        return cls(d["structure"], pseudos, comment=d["comment"], abi_args=d["abi_args"])

    def add_abiobjects(self, *abi_objects):
        """
        This function receive a list of ``AbiVarable`` objects and add
        the corresponding variables to the input.

        Returns:
            dict with all the variables that have been set.

        Raises:
            TypeError: if an object does not provide a ``to_abivars`` method.
        """
        d = {}
        for obj in abi_objects:
            if not hasattr(obj, "to_abivars"):
                raise TypeError("type %s: %s does not have `to_abivars` method" % (type(obj), repr(obj)))
            d.update(self.set_vars(obj.to_abivars()))
        return d

    def __setitem__(self, key, value):
        # SCF tolerances are mutually exclusive: setting one drops any other already present.
        # hasattr guard: __setitem__ may be reached before _vars exists.
        if key in _TOLVARS_SCF and hasattr(self, "_vars") and any(t in self._vars and t != key for t in _TOLVARS_SCF):
            removed = self.remove_vars(_TOLVARS_SCF, strict=False)
            logger.info("Replacing previously set tolerance variable: {0}.".format(removed))

        return super().__setitem__(key, value)

    def _check_varname(self, key):
        if key in GEOVARS:
            raise self.Error(
                "You cannot set the value of a variable associated to the structure.\n"
                "Use Structure objects to prepare the input file."
            )

    def to_string(self, post=None, with_structure=True, with_pseudos=True, exclude=None):
        r"""
        String representation.

        Args:
            post: String that will be appended to the name of the variables
                Note that post is usually autodetected when we have multiple datatasets
                It is mainly used when we have an input file with a single dataset
                so that we can prevent the code from adding "1" to the name of the variables
                (In this case, indeed, Abinit complains if ndtset=1 is not specified
                and we don't want ndtset=1 simply because the code will start to add
                _DS1_ to all the input and output files.
            with_structure: False if section with structure variables should not be printed.
            with_pseudos: False if JSON section with pseudo data should not be added.
            exclude: List of variable names that should be ignored.
        """
        lines = []
        app = lines.append

        if self.comment:
            app("# " + self.comment.replace("\n", "\n#"))

        post = post if post is not None else ""
        exclude = set(exclude) if exclude is not None else set()

        # Variables are always written in alphabetical order; None values and excluded names are skipped.
        keys = sorted(k for k, v in self.items() if k not in exclude and v is not None)

        # Extract the items from the dict and add the geo variables at the end
        items = [(k, self[k]) for k in keys]
        if with_structure:
            items.extend(list(aobj.structure_to_abivars(self.structure).items()))

        for name, value in items:
            # Build variable, convert to string and append it
            vname = name + post
            app(str(InputVariable(vname, value)))

        s = "\n".join(lines)
        if not with_pseudos:
            return s

        # Add JSON section with pseudo potentials.
        ppinfo = ["\n\n\n#<JSON>"]
        d = {"pseudos": [p.as_dict() for p in self.pseudos]}
        ppinfo.extend(json.dumps(d, indent=4).splitlines())
        ppinfo.append("</JSON>")

        s += "\n#".join(ppinfo)
        return s

    @property
    def comment(self):
        """Optional string with comment. None if comment is not set."""
        try:
            return self._comment
        except AttributeError:
            return None

    def set_comment(self, comment):
        """Set a comment to be included at the top of the file."""
        self._comment = comment

    @property
    def structure(self):
        """The |Structure| object associated to this input."""
        return self._structure

    def set_structure(self, structure):
        """
        Set structure.

        Raises:
            BasicAbinitInputError: if the triple product of the lattice vectors is not positive.
        """
        self._structure = as_structure(structure)

        # Check volume
        m = self.structure.lattice.matrix
        if np.dot(np.cross(m[0], m[1]), m[2]) <= 0:
            raise self.Error("The triple product of the lattice vector is negative. Use structure.abi_sanitize.")

        return self._structure

    # Helper functions to facilitate the specification of several variables.
    def set_kmesh(self, ngkpt, shiftk, kptopt=1):
        """
        Set the variables for the sampling of the BZ.

        Args:
            ngkpt: Monkhorst-Pack divisions
            shiftk: List of shifts.
            kptopt: Option for the generation of the mesh.
        """
        shiftk = np.reshape(shiftk, (-1, 3))
        return self.set_vars(ngkpt=ngkpt, kptopt=kptopt, nshiftk=len(shiftk), shiftk=shiftk)

    def set_gamma_sampling(self):
        """Gamma-only sampling of the BZ."""
        return self.set_kmesh(ngkpt=(1, 1, 1), shiftk=(0, 0, 0))

    def set_kpath(self, ndivsm, kptbounds=None, iscf=-2):
        """
        Set the variables for the computation of the electronic band structure.

        Args:
            ndivsm: Number of divisions for the smallest segment.
            kptbounds: k-points defining the path in k-space.
                If None, we use the default high-symmetry k-path defined in the pymatgen database.
            iscf: Value of the iscf variable (default -2).
        """
        if kptbounds is None:
            from pymatgen.symmetry.bandstructure import HighSymmKpath

            hsym_kpath = HighSymmKpath(self.structure)

            name2frac_coords = hsym_kpath.kpath["kpoints"]
            kpath = hsym_kpath.kpath["path"]

            # Concatenate the fractional coordinates of all the points of all the segments.
            frac_coords = [name2frac_coords[name] for segment in kpath for name in segment]
            kptbounds = np.array(frac_coords)

        kptbounds = np.reshape(kptbounds, (-1, 3))
        # self.pop_vars(["ngkpt", "shiftk"]) ??

        return self.set_vars(kptbounds=kptbounds, kptopt=-(len(kptbounds) - 1), ndivsm=ndivsm, iscf=iscf)

    def set_spin_mode(self, spin_mode):
        """
        Set the variables used to the treat the spin degree of freedom.
        Return dictionary with the variables that have been removed.

        Args:
            spin_mode: :class:`SpinMode` object or string. Possible values for string are:

            - polarized
            - unpolarized
            - afm (anti-ferromagnetic)
            - spinor (non-collinear magnetism)
            - spinor_nomag (non-collinear, no magnetism)
        """
        # Remove all variables used to treat spin
        old_vars = self.pop_vars(["nsppol", "nspden", "nspinor"])
        self.add_abiobjects(aobj.SpinMode.as_spinmode(spin_mode))
        return old_vars

    @property
    def pseudos(self):
        """List of |Pseudo| objects."""
        return self._pseudos

    @property
    def ispaw(self):
        """True if PAW calculation."""
        return all(p.ispaw for p in self.pseudos)

    @property
    def isnc(self):
        """True if norm-conserving calculation."""
        return all(p.isnc for p in self.pseudos)

    def new_with_vars(self, *args, **kwargs):
        """
        Return a new input with the given variables.

        Example:
            new = input.new_with_vars(ecut=20)
        """
        # Avoid modifications in self.
        new = self.deepcopy()
        new.set_vars(*args, **kwargs)
        return new

    def pop_tolerances(self):
        """
        Remove all the tolerance variables present in self.
        Return dictionary with the variables that have been removed.
        """
        return self.remove_vars(_TOLVARS, strict=False)

    def pop_irdvars(self):
        """
        Remove all the `ird*` variables present in self.
        Return dictionary with the variables that have been removed.
        """
        return self.remove_vars(_IRDVARS, strict=False)


class BasicMultiDataset:
    """
    This object is essentially a list of BasicAbinitInput objects.
    that provides an easy-to-use interface to apply global changes to the
    the inputs stored in the objects.

    Let's assume for example that multi contains two ``BasicAbinitInput`` objects and we
    want to set `ecut` to 1 in both dictionaries. The direct approach would be:

        for inp in multi:
            inp.set_vars(ecut=1)

    or alternatively:

        for i in range(multi.ndtset):
            multi[i].set_vars(ecut=1)

    BasicMultiDataset provides its own implementaion of __getattr__ so that one can simply use:

        multi.set_vars(ecut=1)

    multi.get("ecut") returns a list of values. It's equivalent to:

        [inp["ecut"] for inp in multi]

    Note that if "ecut" is not present in one of the input of multi, the corresponding entry is set to None.
    A default value can be specified with:

        multi.get("paral_kgb", 0)

    .. warning::

        BasicMultiDataset does not support calculations done with different sets of pseudopotentials.
        The inputs can have different crystalline structures (as long as the atom types are equal)
        but each input in BasicMultiDataset must have the same set of pseudopotentials.
    """

    Error = BasicAbinitInputError

    @classmethod
    def from_inputs(cls, inputs):
        """Build object from a list of BasicAbinitInput objects."""
        for inp in inputs:
            if any(p1 != p2 for p1, p2 in zip(inputs[0].pseudos, inp.pseudos)):
                raise ValueError("Pseudos must be consistent when from_inputs is invoked.")

        # Build BasicMultiDataset from input structures and pseudos and add inputs.
        multi = cls(
            structure=[inp.structure for inp in inputs],
            pseudos=inputs[0].pseudos,
            ndtset=len(inputs),
        )

        # Add variables
        for inp, new_inp in zip(inputs, multi):
            new_inp.set_vars(**inp)

        return multi

    @classmethod
    def replicate_input(cls, input, ndtset):
        """Construct a multidataset with ndtset from the BasicAbinitInput input."""
        multi = cls(input.structure, input.pseudos, ndtset=ndtset)

        for inp in multi:
            inp.set_vars(**input)

        return multi

    def __init__(self, structure, pseudos, pseudo_dir="", ndtset=1):
        """
        Args:
            structure: file with the structure, |Structure| object or dictionary with ABINIT geo variable
                Accepts also list of objects that can be converted to Structure object.
                In this case, however, ndtset must be equal to the length of the list.
            pseudos: String or list of string with the name of the pseudopotential files.
            pseudo_dir: Name of the directory where the pseudopotential files are located.
            ndtset: Number of datasets.
        """
        # Setup of the pseudopotential files.
        if isinstance(pseudos, Pseudo):
            pseudos = [pseudos]

        elif isinstance(pseudos, PseudoTable):
            pseudos = pseudos

        elif all(isinstance(p, Pseudo) for p in pseudos):
            pseudos = PseudoTable(pseudos)

        else:
            # String(s)
            pseudo_dir = os.path.abspath(pseudo_dir)
            pseudo_paths = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)]

            missing = [p for p in pseudo_paths if not os.path.exists(p)]
            if missing:
                raise self.Error("Cannot find the following pseudopotential files:\n%s" % str(missing))

            pseudos = PseudoTable(pseudo_paths)

        # Build the list of BasicAbinitInput objects.
1127 if ndtset <= 0: 1128 raise ValueError("ndtset %d cannot be <=0" % ndtset) 1129 1130 if not isinstance(structure, (list, tuple)): 1131 self._inputs = [BasicAbinitInput(structure=structure, pseudos=pseudos) for i in range(ndtset)] 1132 else: 1133 assert len(structure) == ndtset 1134 self._inputs = [BasicAbinitInput(structure=s, pseudos=pseudos) for s in structure] 1135 1136 @property 1137 def ndtset(self): 1138 """Number of inputs in self.""" 1139 return len(self) 1140 1141 @property 1142 def pseudos(self): 1143 """Pseudopotential objects.""" 1144 return self[0].pseudos 1145 1146 @property 1147 def ispaw(self): 1148 """True if PAW calculation.""" 1149 return all(p.ispaw for p in self.pseudos) 1150 1151 @property 1152 def isnc(self): 1153 """True if norm-conserving calculation.""" 1154 return all(p.isnc for p in self.pseudos) 1155 1156 def __len__(self): 1157 return len(self._inputs) 1158 1159 def __getitem__(self, key): 1160 return self._inputs[key] 1161 1162 def __iter__(self): 1163 return self._inputs.__iter__() 1164 1165 def __getattr__(self, name): 1166 _inputs = object.__getattribute__(self, "_inputs") 1167 m = getattr(_inputs[0], name) 1168 if m is None: 1169 raise AttributeError( 1170 "Cannot find attribute %s. 
Tried in %s and then in BasicAbinitInput object" 1171 % (self.__class__.__name__, name) 1172 ) 1173 isattr = not callable(m) 1174 1175 def on_all(*args, **kwargs): 1176 results = [] 1177 for obj in self._inputs: 1178 a = getattr(obj, name) 1179 # print("name", name, ", type:", type(a), "callable: ",callable(a)) 1180 if callable(a): 1181 results.append(a(*args, **kwargs)) 1182 else: 1183 results.append(a) 1184 1185 return results 1186 1187 if isattr: 1188 on_all = on_all() 1189 1190 return on_all 1191 1192 def __add__(self, other): 1193 """self + other""" 1194 if isinstance(other, BasicAbinitInput): 1195 new_mds = BasicMultiDataset.from_inputs(self) 1196 new_mds.append(other) 1197 return new_mds 1198 if isinstance(other, BasicMultiDataset): 1199 new_mds = BasicMultiDataset.from_inputs(self) 1200 new_mds.extend(other) 1201 return new_mds 1202 1203 raise NotImplementedError("Operation not supported") 1204 1205 def __radd__(self, other): 1206 if isinstance(other, BasicAbinitInput): 1207 new_mds = BasicMultiDataset.from_inputs([other]) 1208 new_mds.extend(self) 1209 elif isinstance(other, BasicMultiDataset): 1210 new_mds = BasicMultiDataset.from_inputs(other) 1211 new_mds.extend(self) 1212 else: 1213 raise NotImplementedError("Operation not supported") 1214 1215 def append(self, abinit_input): 1216 """Add a |BasicAbinitInput| to the list.""" 1217 assert isinstance(abinit_input, BasicAbinitInput) 1218 if any(p1 != p2 for p1, p2 in zip(abinit_input.pseudos, abinit_input.pseudos)): 1219 raise ValueError("Pseudos must be consistent when from_inputs is invoked.") 1220 self._inputs.append(abinit_input) 1221 1222 def extend(self, abinit_inputs): 1223 """Extends self with a list of |BasicAbinitInput| objects.""" 1224 assert all(isinstance(inp, BasicAbinitInput) for inp in abinit_inputs) 1225 for inp in abinit_inputs: 1226 if any(p1 != p2 for p1, p2 in zip(self[0].pseudos, inp.pseudos)): 1227 raise ValueError("Pseudos must be consistent when from_inputs is invoked.") 1228 
self._inputs.extend(abinit_inputs) 1229 1230 def addnew_from(self, dtindex): 1231 """Add a new entry in the multidataset by copying the input with index ``dtindex``.""" 1232 self.append(self[dtindex].deepcopy()) 1233 1234 def split_datasets(self): 1235 """Return list of |BasicAbinitInput| objects..""" 1236 return self._inputs 1237 1238 def deepcopy(self): 1239 """Deep copy of the BasicMultiDataset.""" 1240 return copy.deepcopy(self) 1241 1242 @property 1243 def has_same_structures(self): 1244 """True if all inputs in BasicMultiDataset are equal.""" 1245 return all(self[0].structure == inp.structure for inp in self) 1246 1247 def __str__(self): 1248 return self.to_string() 1249 1250 def to_string(self, with_pseudos=True): 1251 """ 1252 String representation i.e. the input file read by Abinit. 1253 1254 Args: 1255 with_pseudos: False if JSON section with pseudo data should not be added. 1256 """ 1257 if self.ndtset > 1: 1258 # Multi dataset mode. 1259 lines = ["ndtset %d" % self.ndtset] 1260 1261 def has_same_variable(kref, vref, other_inp): 1262 """True if variable kref is present in other_inp with the same value.""" 1263 if kref not in other_inp: 1264 return False 1265 otherv = other_inp[kref] 1266 return np.array_equal(vref, otherv) 1267 1268 # Don't repeat variable that are common to the different datasets. 
1269 # Put them in the `Global Variables` section and exclude these variables in inp.to_string 1270 global_vars = set() 1271 for k0, v0 in self[0].items(): 1272 isame = True 1273 for i in range(1, self.ndtset): 1274 isame = has_same_variable(k0, v0, self[i]) 1275 if not isame: 1276 break 1277 if isame: 1278 global_vars.add(k0) 1279 # print("global_vars vars", global_vars) 1280 1281 w = 92 1282 if global_vars: 1283 lines.append(w * "#") 1284 lines.append("### Global Variables.") 1285 lines.append(w * "#") 1286 for key in global_vars: 1287 vname = key 1288 lines.append(str(InputVariable(vname, self[0][key]))) 1289 1290 has_same_structures = self.has_same_structures 1291 if has_same_structures: 1292 # Write structure here and disable structure output in input.to_string 1293 lines.append(w * "#") 1294 lines.append("#" + ("STRUCTURE").center(w - 1)) 1295 lines.append(w * "#") 1296 for key, value in aobj.structure_to_abivars(self[0].structure).items(): 1297 vname = key 1298 lines.append(str(InputVariable(vname, value))) 1299 1300 for i, inp in enumerate(self): 1301 header = "### DATASET %d ###" % (i + 1) 1302 is_last = i == self.ndtset - 1 1303 s = inp.to_string( 1304 post=str(i + 1), 1305 with_pseudos=is_last and with_pseudos, 1306 with_structure=not has_same_structures, 1307 exclude=global_vars, 1308 ) 1309 if s: 1310 header = len(header) * "#" + "\n" + header + "\n" + len(header) * "#" + "\n" 1311 s = "\n" + header + s + "\n" 1312 1313 lines.append(s) 1314 1315 return "\n".join(lines) 1316 1317 # single datasets ==> don't append the dataset index to the variables. 1318 # this trick is needed because Abinit complains if ndtset is not specified 1319 # and we have variables that end with the dataset index e.g. acell1 1320 # We don't want to specify ndtset here since abinit will start to add DS# to 1321 # the input and output files thus complicating the algorithms we have to use to locate the files. 
1322 return self[0].to_string(with_pseudos=with_pseudos) 1323 1324 def write(self, filepath="run.abi"): 1325 """ 1326 Write ``ndset`` input files to disk. The name of the file 1327 is constructed from the dataset index e.g. run0.abi 1328 """ 1329 root, ext = os.path.splitext(filepath) 1330 for i, inp in enumerate(self): 1331 p = root + "DS%d" % i + ext 1332 inp.write(filepath=p) 1333