1"""
2This module defines a simplified interface for generating ABINIT input files.
3Note that not all the features of Abinit are supported by BasicAbinitInput.
4For a more comprehensive implementation, use the AbinitInput object provided by AbiPy.
5"""
6
7import abc
8import copy
9import json
10import logging
11import os
12from collections import OrderedDict, namedtuple
13from collections.abc import Mapping, MutableMapping
14from enum import Enum
15
16import numpy as np
17from monty.collections import AttrDict
18from monty.json import MSONable
19from monty.string import is_string, list_strings
20
21from pymatgen.core.structure import Structure
22from pymatgen.io.abinit import abiobjects as aobj
23from pymatgen.io.abinit.pseudos import Pseudo, PseudoTable
24from pymatgen.io.abinit.variable import InputVariable
25from pymatgen.util.serialization import pmg_serialize
26
logger = logging.getLogger(__file__)


# List of Abinit variables used to specify the structure.
# These variables should not be passed to set_vars since
# they are generated from the Structure object (see aobj.structure_to_abivars).
# NOTE: every name needs its own trailing comma, otherwise adjacent string
# literals are silently concatenated (e.g. "rprimd" "angdeg" -> "rprimdangdeg").
GEOVARS = {
    "acell",
    "rprim",
    "rprimd",
    "angdeg",
    "xred",
    "xcart",
    "xangst",
    "znucl",
    "typat",
    "ntypat",
    "natom",
}
47
# Tolerances controlling the SCF cycle. They are mutually exclusive:
# only one of them should be present in a dataset.
_TOLVARS_SCF = {"toldfe", "tolvrs", "tolwfr", "tolrff", "toldff"}

# All the variables defining tolerances (used in pop_tolerances):
# the SCF tolerances above plus tolimg, tolmxf and tolrde.
_TOLVARS = _TOLVARS_SCF | {"tolimg", "tolmxf", "tolrde"}
72
# Variables telling Abinit whether data files given in input should be read
# (the ird* family of input variables).
_IRDVARS = {
    "irdbseig", "irdbsreso", "irdhaydock", "irdddk", "irdden",
    "ird1den", "irdqps", "irdkss", "irdscr", "irdsuscep",
    "irdvdw", "irdwfk", "irdwfkfine", "irdwfq", "ird1wf",
}
93
# Default tolerance variable associated to each runlevel.
# Entries marked "?" are not settled yet, those marked "dummy" are placeholders.
_runl2tolname = dict(
    scf="tolvrs",
    nscf="tolwfr",
    dfpt="toldfe",  # ?
    screening="toldfe",  # dummy
    sigma="toldfe",  # dummy
    bse="toldfe",  # ?
    relax="tolrff",
)
104
# Tolerances for the different levels of accuracy.
# Each entry maps a tolerance variable to its (low, normal, high) values.
# The namedtuple class is only needed while building the table, hence the del.
_T = namedtuple("T", ["low", "normal", "high"])
_tolerances = {
    "toldfe": _T(low=1.0e-7, normal=1.0e-8, high=1.0e-9),
    "tolvrs": _T(low=1.0e-7, normal=1.0e-8, high=1.0e-9),
    "tolwfr": _T(low=1.0e-15, normal=1.0e-17, high=1.0e-19),
    "tolrff": _T(low=0.04, normal=0.02, high=0.01),
}
del _T
115
116
# Fallback values used when the caller does not provide them.
_DEFAULTS = {"kppa": 1000}
121
122
def as_structure(obj):
    """
    Convert ``obj`` into a |Structure|.

    Accepted inputs:

        - Structure object (returned as is).
        - Filename (read with ``Structure.from_file``).
        - Dictionary: MSONable dict (recognized by the ``@module`` key)
          or dictionary with abinit variables.

    Raises:
        TypeError: if ``obj`` is none of the above.
    """
    if isinstance(obj, Structure):
        return obj

    if is_string(obj):
        return Structure.from_file(obj)

    if not isinstance(obj, Mapping):
        raise TypeError("Don't know how to convert %s into a structure" % type(obj))

    is_msonable_dict = "@module" in obj
    if is_msonable_dict:
        return Structure.from_dict(obj)
    return aobj.structure_from_abivars(cls=None, **obj)
143
144
class ShiftMode(Enum):
    """
    Class defining the mode to be used for the shifts.
    G: Gamma centered
    M: Monkhorst-Pack ((0.5, 0.5, 0.5))
    S: Symmetric. Respects the chksymbreak with multiple shifts
    O: OneSymmetric. Respects the chksymbreak with a single shift (as in 'S' if a single shift is given, gamma
        centered otherwise).
    """

    GammaCentered = "G"
    MonkhorstPack = "M"
    Symmetric = "S"
    OneSymmetric = "O"

    @classmethod
    def from_object(cls, obj):
        """
        Return the ShiftMode corresponding to ``obj``.

        Args:
            obj: ShiftMode instance (returned unchanged) or string. Strings are converted
                according to their initial letter, case insensitive: G for GammaCentered,
                M for MonkhorstPack, S for Symmetric, O for OneSymmetric.

        Raises:
            ValueError: if ``obj`` is an empty string or its initial letter does not
                correspond to any ShiftMode.
            TypeError: if ``obj`` is neither a ShiftMode nor a string.
        """
        if isinstance(obj, cls):
            return obj
        if is_string(obj):
            if not obj:
                # Without this check obj[0] would raise an uninformative IndexError.
                raise ValueError("Cannot convert an empty string into a ShiftMode")
            return cls(obj[0].upper())
        raise TypeError("The object provided is not handled: type %s" % type(obj))
173
174
def _stopping_criterion(runlevel, accuracy):
    """
    Return the stopping criterion for ``runlevel`` at the requested ``accuracy``.

    The result is a single-entry dict ``{tolerance_name: value}`` that can be passed
    to set_vars. ``accuracy`` selects a field of the tolerance tuple ("low", "normal", "high").
    """
    tolname = _runl2tolname[runlevel]
    per_accuracy = _tolerances[tolname]
    value = getattr(per_accuracy, accuracy)
    return {tolname: value}
179
180
def _find_ecut_pawecutdg(ecut, pawecutdg, pseudos, accuracy):
    """
    Return a |AttrDict| with the value of ``ecut`` and ``pawecutdg``.

    Values passed as None are deduced from the hints of the pseudopotentials for the
    given ``accuracy`` (largest hint among all the pseudos). ``pawecutdg`` is deduced only
    if at least one pseudo is PAW; otherwise it is returned unchanged.

    Raises:
        RuntimeError: if a value has to be deduced but not all the pseudos provide hints.
    """
    if ecut is None:
        if not all(p.has_hints for p in pseudos):
            raise RuntimeError("ecut is None but pseudos do not provide hints for ecut")
        ecut = max(p.hint_for_accuracy(accuracy).ecut for p in pseudos)

    if pawecutdg is None and any(p.ispaw for p in pseudos):
        if not all(p.has_hints for p in pseudos):
            raise RuntimeError("pawecutdg is None but pseudos do not provide hints")
        pawecutdg = max(p.hint_for_accuracy(accuracy).pawecutdg for p in pseudos)

    return AttrDict(ecut=ecut, pawecutdg=pawecutdg)
200
201
def _find_scf_nband(structure, pseudos, electrons, spinat=None):
    """
    Find the value of ``nband`` for the SCF run.

    If ``electrons.nband`` is already set, it is returned unchanged. Otherwise nband is
    estimated from the number of valence electrons minus ``electrons.charge``. Extra bands
    are added when smearing is used and, for nsppol == 2 with ``spinat`` given, for the
    starting magnetization. The estimate is always rounded up to an even integer.
    """
    if electrons.nband is not None:
        return electrons.nband

    # Number of valence electrons including possible extra charge.
    nval = num_valence_electrons(structure, pseudos) - electrons.charge

    # First guess (semiconductors).
    estimate = nval // 2

    # TODO: Find better algorithm
    # If nband is too small we may kill the job, increase nband and restart
    # but this change could cause problems in the other steps of the calculation
    # if the change is not propagated e.g. phonons in metals.
    if electrons.smearing:
        # Metallic occupation: leave more room for the empty states.
        estimate = max(np.ceil(estimate * 1.2), estimate + 10)
    else:
        estimate = max(np.ceil(estimate * 1.1), estimate + 4)

    # Extra bands based on the starting magnetization.
    if electrons.nsppol == 2 and spinat is not None:
        estimate += np.ceil(max(np.sum(spinat, axis=0)) / 2.0)

    # Round up to an even value (easier to divide among procs, mandatory if nspinor == 2).
    estimate += estimate % 2
    return int(estimate)
233
234
def _get_shifts(shift_mode, structure):
    """
    Gives the shifts based on the selected shift mode and on the symmetry of the structure.
    G: Gamma centered
    M: Monkhorst-Pack ((0.5, 0.5, 0.5))
    S: Symmetric. Respects the chksymbreak with multiple shifts
    O: OneSymmetric. Respects the chksymbreak with a single shift (as in 'S' if a single shift is given, gamma
        centered otherwise).

    Note: for some cases (e.g. body centered tetragonal), both the Symmetric and OneSymmetric may fail to satisfy the
        ``chksymbreak`` condition (Abinit input variable).

    Raises:
        ValueError: if ``shift_mode`` is not a ShiftMode member.
    """
    if shift_mode == ShiftMode.GammaCentered:
        return ((0, 0, 0),)
    if shift_mode == ShiftMode.MonkhorstPack:
        return ((0.5, 0.5, 0.5),)
    if shift_mode not in (ShiftMode.Symmetric, ShiftMode.OneSymmetric):
        raise ValueError("invalid shift_mode: `%s`" % str(shift_mode))

    symmetric_shifts = calc_shiftk(structure)
    if shift_mode == ShiftMode.Symmetric or len(symmetric_shifts) == 1:
        return symmetric_shifts

    # OneSymmetric but several shifts would be needed: fall back to a Gamma-centered mesh.
    return ((0, 0, 0),)
260
261
def gs_input(
    structure,
    pseudos,
    kppa=None,
    ecut=None,
    pawecutdg=None,
    scf_nband=None,
    accuracy="normal",
    spin_mode="polarized",
    smearing="fermi_dirac:0.1 eV",
    charge=0.0,
    scf_algorithm=None,
):
    """
    Returns a |BasicAbinitInput| for ground-state calculation.

    The input is the SCF dataset produced by ``ebands_input`` with ``ndivsm=0``.

    Args:
        structure: |Structure| object.
        pseudos: List of filenames or list of |Pseudo| objects or |PseudoTable| object.
        kppa: Defines the sampling used for the SCF run. Defaults to 1000 if not given.
        ecut: cutoff energy in Ha (if None, ecut is initialized from the pseudos according to accuracy)
        pawecutdg: cutoff energy in Ha for PAW double-grid (if None, pawecutdg is initialized from the pseudos
                   according to accuracy)
        scf_nband: Number of bands for SCF run. If scf_nband is None, nband is automatically initialized
                   from the list of pseudos, the structure and the smearing option.
        accuracy: Accuracy of the calculation.
        spin_mode: Spin polarization.
        smearing: Smearing technique.
        charge: Electronic charge added to the unit cell.
        scf_algorithm: Algorithm used for solving of the SCF cycle.
    """
    # ndivsm=0 makes ebands_input stop right after the SCF dataset.
    options = dict(
        kppa=kppa,
        ndivsm=0,
        ecut=ecut,
        pawecutdg=pawecutdg,
        scf_nband=scf_nband,
        accuracy=accuracy,
        spin_mode=spin_mode,
        smearing=smearing,
        charge=charge,
        scf_algorithm=scf_algorithm,
    )
    scf_only = ebands_input(structure, pseudos, **options)
    return scf_only[0]
309
310
def ebands_input(
    structure,
    pseudos,
    kppa=None,
    nscf_nband=None,
    ndivsm=15,
    ecut=None,
    pawecutdg=None,
    scf_nband=None,
    accuracy="normal",
    spin_mode="polarized",
    smearing="fermi_dirac:0.1 eV",
    charge=0.0,
    scf_algorithm=None,
    dos_kppa=None,
):
    """
    Returns a |BasicMultiDataset| object for band structure calculations.

    Dataset layout: multi[0] is the SCF run, multi[1] the NSCF run along the k-path and,
    if ``dos_kppa`` is given, multi[2:] contains one NSCF run per value of ``dos_kppa``.

    Args:
        structure: |Structure| object.
        pseudos: List of filenames or list of |Pseudo| objects or |PseudoTable| object.
        kppa: Defines the sampling used for the SCF run. Defaults to 1000 if not given.
        nscf_nband: Number of bands included in the NSCF run. Set to scf_nband + 10 if None.
        ndivsm: Number of divisions used to sample the smallest segment of the k-path.
                if 0, only the GS input is returned in multi[0].
        ecut: cutoff energy in Ha (if None, ecut is initialized from the pseudos according to accuracy)
        pawecutdg: cutoff energy in Ha for PAW double-grid (if None, pawecutdg is initialized from the pseudos
            according to accuracy)
        scf_nband: Number of bands for SCF run. If scf_nband is None, nband is automatically initialized
            from the list of pseudos, the structure and the smearing option.
        accuracy: Accuracy of the calculation.
        spin_mode: Spin polarization.
        smearing: Smearing technique.
        charge: Electronic charge added to the unit cell.
        scf_algorithm: Algorithm used for solving of the SCF cycle.
        dos_kppa: Scalar or List of integers with the number of k-points per atom
            to be used for the computation of the DOS (None if DOS is not wanted).
    """
    structure = as_structure(structure)

    # Normalize dos_kppa to a sequence (None means that the DOS is not wanted).
    if dos_kppa is not None and not isinstance(dos_kppa, (list, tuple)):
        dos_kppa = [dos_kppa]
    num_dos = 0 if dos_kppa is None else len(dos_kppa)

    multi = BasicMultiDataset(structure, pseudos, ndtset=2 + num_dos)

    # Cutoff energies are shared by all the datasets.
    multi.set_vars(_find_ecut_pawecutdg(ecut, pawecutdg, multi.pseudos, accuracy))

    # --- Dataset 0: SCF run ---
    if kppa is None:
        kppa = _DEFAULTS.get("kppa")
    scf_ksampling = aobj.KSampling.automatic_density(structure, kppa, chksymbreak=0)
    scf_electrons = aobj.Electrons(
        spin_mode=spin_mode,
        smearing=smearing,
        algorithm=scf_algorithm,
        charge=charge,
        nband=scf_nband,
        fband=None,
    )
    if scf_electrons.nband is None:
        scf_electrons.nband = _find_scf_nband(structure, multi.pseudos, scf_electrons, multi[0].get("spinat", None))

    scf_input = multi[0]
    scf_input.set_vars(scf_ksampling.to_abivars())
    scf_input.set_vars(scf_electrons.to_abivars())
    scf_input.set_vars(_stopping_criterion("scf", accuracy))

    # ndivsm == 0: only the GS input is wanted.
    if ndivsm == 0:
        return multi

    # --- Dataset 1: NSCF run along the k-path ---
    path_ksampling = aobj.KSampling.path_from_structure(ndivsm, structure)
    if nscf_nband is None:
        nscf_nband = scf_electrons.nband + 10
    path_electrons = aobj.Electrons(
        spin_mode=spin_mode,
        smearing=smearing,
        algorithm={"iscf": -2},
        charge=charge,
        nband=nscf_nband,
        fband=None,
    )

    path_input = multi[1]
    path_input.set_vars(path_ksampling.to_abivars())
    path_input.set_vars(path_electrons.to_abivars())
    path_input.set_vars(_stopping_criterion("nscf", accuracy))

    # --- Datasets 2, 3, ...: NSCF runs for the DOS, one per value of kppa ---
    for dt, density in enumerate(dos_kppa or (), start=2):
        dos_ksampling = aobj.KSampling.automatic_density(structure, density, chksymbreak=0)
        dos_electrons = aobj.Electrons(
            spin_mode=spin_mode,
            smearing=smearing,
            algorithm={"iscf": -2},
            charge=charge,
            nband=nscf_nband,
        )
        dos_input = multi[dt]
        dos_input.set_vars(dos_ksampling.to_abivars())
        dos_input.set_vars(dos_electrons.to_abivars())
        dos_input.set_vars(_stopping_criterion("nscf", accuracy))

    return multi
415
416
def ion_ioncell_relax_input(
    structure,
    pseudos,
    kppa=None,
    nband=None,
    ecut=None,
    pawecutdg=None,
    accuracy="normal",
    spin_mode="polarized",
    smearing="fermi_dirac:0.1 eV",
    charge=0.0,
    scf_algorithm=None,
    shift_mode="Monkhorst-pack",
):
    """
    Returns a |BasicMultiDataset| for a structural relaxation. The first dataset optimizes the
    atomic positions at fixed unit cell. The second dataset optimizes both ions and unit cell parameters.

    Args:
        structure: |Structure| object.
        pseudos: List of filenames or list of |Pseudo| objects or |PseudoTable| object.
        kppa: Defines the sampling used for the Brillouin zone. Defaults to 1000 if not given.
        nband: Number of bands included in the SCF run. Estimated automatically if None.
        ecut: cutoff energy in Ha (if None, ecut is initialized from the pseudos according to accuracy)
        pawecutdg: cutoff energy in Ha for PAW double-grid (if None, pawecutdg is initialized from the pseudos
            according to accuracy)
        accuracy: Accuracy of the calculation.
        spin_mode: Spin polarization.
        smearing: Smearing technique.
        charge: Electronic charge added to the unit cell.
        scf_algorithm: Algorithm used for the solution of the SCF cycle.
        shift_mode: Mode used to choose the k-point shifts. Anything accepted by
            ShiftMode.from_object (the default "Monkhorst-pack" selects MonkhorstPack).
    """
    structure = as_structure(structure)
    multi = BasicMultiDataset(structure, pseudos, ndtset=2)

    # Cutoff energies (shared by both datasets).
    multi.set_vars(_find_ecut_pawecutdg(ecut, pawecutdg, multi.pseudos, accuracy))

    if kppa is None:
        kppa = _DEFAULTS.get("kppa")

    shifts = _get_shifts(ShiftMode.from_object(shift_mode), structure)
    ksampling = aobj.KSampling.automatic_density(structure, kppa, chksymbreak=0, shifts=shifts)
    electrons = aobj.Electrons(
        spin_mode=spin_mode,
        smearing=smearing,
        algorithm=scf_algorithm,
        charge=charge,
        nband=nband,
        fband=None,
    )
    if electrons.nband is None:
        electrons.nband = _find_scf_nband(structure, multi.pseudos, electrons, multi[0].get("spinat", None))

    # One relaxation method per dataset: ions only, then ions + cell.
    relax_methods = (
        aobj.RelaxationMethod.atoms_only(atoms_constraints=None),
        aobj.RelaxationMethod.atoms_and_cell(atoms_constraints=None),
    )

    # Variables common to both datasets.
    multi.set_vars(electrons.to_abivars())
    multi.set_vars(ksampling.to_abivars())

    for idx, method in enumerate(relax_methods):
        multi[idx].set_vars(method.to_abivars())
        multi[idx].set_vars(_stopping_criterion("relax", accuracy))

    return multi
482
483
def calc_shiftk(structure, symprec=0.01, angle_tolerance=5):
    """
    Find the values of ``shiftk`` and ``nshiftk`` appropriated for the sampling of the Brillouin zone.

    When the primitive vectors of the lattice do NOT form a FCC or a BCC lattice,
    the usual (shifted) Monkhorst-Pack grids are formed by using nshiftk=1 and shiftk 0.5 0.5 0.5 .
    This is often the preferred k point sampling. For a non-shifted Monkhorst-Pack grid,
    use `nshiftk=1` and `shiftk 0.0 0.0 0.0`, but there is little reason to do that.

    When the primitive vectors of the lattice form a FCC lattice, with rprim::

            0.0 0.5 0.5
            0.5 0.0 0.5
            0.5 0.5 0.0

    the (very efficient) usual Monkhorst-Pack sampling will be generated by using nshiftk= 4 and shiftk::

        0.5 0.5 0.5
        0.5 0.0 0.0
        0.0 0.5 0.0
        0.0 0.0 0.5

    When the primitive vectors of the lattice form a BCC lattice, with rprim::

           -0.5  0.5  0.5
            0.5 -0.5  0.5
            0.5  0.5 -0.5

    the usual Monkhorst-Pack sampling will be generated by using nshiftk= 2 and shiftk::

            0.25  0.25  0.25
           -0.25 -0.25 -0.25

    However, the simple sampling nshiftk=1 and shiftk 0.5 0.5 0.5 is excellent.

    For hexagonal lattices with hexagonal axes, e.g. rprim::

            1.0  0.0       0.0
           -0.5  sqrt(3)/2 0.0
            0.0  0.0       1.0

    one can use nshiftk= 1 and shiftk 0.0 0.0 0.5
    In rhombohedral axes, e.g. using angdeg 3*60., this corresponds to shiftk 0.5 0.5 0.5,
    to keep the shift along the symmetry axis.

    Args:
        structure: |Structure| object.
        symprec: Symmetry precision passed to SpacegroupAnalyzer.
        angle_tolerance: Angle tolerance passed to SpacegroupAnalyzer.

    Returns:
        Suggested value of shiftk as a numpy array with shape (nshiftk, 3).
        Special shifts are used only if the cell is primitive; otherwise (and for lattice
        types without special shifts) the default shift (0.5, 0.5, 0.5) is returned.

    Raises:
        ValueError: if the lattice is hexagonal but no angle close to 120 degrees is found.
    """
    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # Find lattice type and space group.
    analyzer = SpacegroupAnalyzer(structure, symprec=symprec, angle_tolerance=angle_tolerance)
    lattice_type = analyzer.get_lattice_type()
    spg_symbol = analyzer.get_space_group_symbol()

    # Special shifts are used only when the input cell is primitive.
    is_primitive = len(analyzer.find_primitive()) == len(structure)

    shiftk = None
    if is_primitive:
        if lattice_type == "cubic" and "F" in spg_symbol:
            # FCC: four shifts.
            shiftk = [0.5, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.5]
        elif lattice_type == "cubic" and "I" in spg_symbol:
            # BCC: two shifts (the single shift 0.5 0.5 0.5 would also be a good choice).
            shiftk = [0.25, 0.25, 0.25, -0.25, -0.25, -0.25]
        elif lattice_type == "hexagonal":
            # angles[i] is the angle between the two lattice vectors other than i,
            # hence vector i is the hexagonal axis when that angle is ~120 degrees.
            for hex_ax, angle in enumerate(structure.lattice.angles):
                if abs(angle - 120) < 1.0:
                    break
            else:
                raise ValueError("Cannot find hexagonal axis")
            shiftk = [0.0, 0.0, 0.0]
            shiftk[hex_ax] = 0.5
        elif lattice_type == "tetragonal" and "I" in spg_symbol:
            # BCT
            shiftk = [0.25, 0.25, 0.25, -0.25, -0.25, -0.25]

    # Fall back to the usual shifted Monkhorst-Pack grid.
    return np.reshape([0.5, 0.5, 0.5] if shiftk is None else shiftk, (-1, 3))
579
580
def num_valence_electrons(structure, pseudos):
    """
    Returns the number of valence electrons.

    Args:
        structure: |Structure| object.
        pseudos: List of |Pseudo| objects or list of filenames.

    Returns:
        Sum of the ``Z_val`` of the pseudopotential associated to the element of each site.
        The value is converted to int when the sum is integral.
    """
    table = PseudoTable.as_table(pseudos)
    total = sum(table.pseudo_with_symbol(site.specie.symbol).Z_val for site in structure)
    return int(total) if int(total) == total else total
594
595
class AbstractInput(MutableMapping, metaclass=abc.ABCMeta):
    """
    Abstract class defining the methods that must be implemented by Input objects.

    Subclasses must provide:
        - ``vars``: mapping with the input variables, used to implement the dict-like interface.
        - ``_check_varname``: validation hook called before a variable is set.
        - ``to_string``: string representation of the input.
    """

    # ABC protocol: __delitem__, __getitem__, __iter__, __len__, __setitem__
    def __delitem__(self, key):
        return self.vars.__delitem__(key)

    def __getitem__(self, key):
        return self.vars.__getitem__(key)

    def __iter__(self):
        return self.vars.__iter__()

    def __len__(self):
        return len(self.vars)

    def __setitem__(self, key, value):
        self._check_varname(key)
        return self.vars.__setitem__(key, value)

    def __repr__(self):
        return "<%s at %s>" % (self.__class__.__name__, id(self))

    def __str__(self):
        return self.to_string()

    def write(self, filepath="run.abi"):
        """
        Write the input file to ``filepath``, creating the parent directory if needed.
        """
        dirname = os.path.dirname(os.path.abspath(filepath))
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(dirname, exist_ok=True)

        # Write the input file.
        with open(filepath, "wt") as fh:
            fh.write(str(self))

    def deepcopy(self):
        """Deep copy of the input."""
        return copy.deepcopy(self)

    def set_vars(self, *args, **kwargs):
        """
        Set the value of the variables.
        Return dict with the variables added to the input.

        Accepts an optional mapping (or iterable of pairs) as positional argument and/or
        keyword arguments. Entries of the positional argument take precedence.

        Example:

            input.set_vars(ecut=10, ionmov=3)
        """
        kwargs.update(dict(*args))
        for varname, varvalue in kwargs.items():
            self[varname] = varvalue
        return kwargs

    def set_vars_ifnotin(self, *args, **kwargs):
        """
        Set the value of the variables but only if the variable is not already present.
        Return dict with the variables added to the input.

        Example:

            input.set_vars_ifnotin(ecut=10, ionmov=3)
        """
        kwargs.update(dict(*args))
        added = {}
        for varname, varvalue in kwargs.items():
            if varname not in self:
                self[varname] = varvalue
                added[varname] = varvalue
        return added

    def pop_vars(self, keys):
        """
        Remove the variables listed in keys.
        Return dictionary with the variables that have been removed.
        Unlike remove_vars, no exception is raised if the variables are not in the input.

        Args:
            keys: string or list of strings with variable names.

        Example:
            inp.pop_vars(["ionmov", "optcell", "ntime", "dilatmx"])
        """
        return self.remove_vars(keys, strict=False)

    def remove_vars(self, keys, strict=True):
        """
        Remove the variables listed in keys.
        Return dictionary with the variables that have been removed.

        Args:
            keys: string or list of strings with variable names.
            strict: If True, KeyError is raised if at least one variable is not present.
        """
        removed = {}
        for key in list_strings(keys):
            if strict and key not in self:
                raise KeyError("key: %s not in self:\n %s" % (key, list(self.keys())))
            if key in self:
                removed[key] = self.pop(key)

        return removed

    # abc.abstractproperty is deprecated: stack property and abstractmethod instead.
    @property
    @abc.abstractmethod
    def vars(self):
        """Dictionary with the input variables. Used to implement dict-like interface."""

    @abc.abstractmethod
    def _check_varname(self, key):
        """Check if key is a valid name. Raise self.Error if not valid."""

    @abc.abstractmethod
    def to_string(self):
        """Returns a string with the input."""
714
715
class BasicAbinitInputError(Exception):
    """Exception raised by ``BasicAbinitInput`` (exposed as ``BasicAbinitInput.Error``)."""
718
719
720class BasicAbinitInput(AbstractInput, MSONable):
721    """
722    This object stores the ABINIT variables for a single dataset.
723    """
724
725    Error = BasicAbinitInputError
726
727    def __init__(
728        self,
729        structure,
730        pseudos,
731        pseudo_dir=None,
732        comment=None,
733        abi_args=None,
734        abi_kwargs=None,
735    ):
736        """
737        Args:
738            structure: Parameters defining the crystalline structure. Accepts |Structure| object
739            file with structure (CIF, netcdf file, ...) or dictionary with ABINIT geo variables.
740            pseudos: Pseudopotentials to be used for the calculation. Accepts: string or list of strings
741                with the name of the pseudopotential files, list of |Pseudo| objects
742                or |PseudoTable| object.
743            pseudo_dir: Name of the directory where the pseudopotential files are located.
744            ndtset: Number of datasets.
745            comment: Optional string with a comment that will be placed at the beginning of the file.
746            abi_args: list of tuples (key, value) with the initial set of variables. Default: Empty
747            abi_kwargs: Dictionary with the initial set of variables. Default: Empty
748        """
749        # Internal dict with variables. we use an ordered dict so that
750        # variables will be likely grouped by `topics` when we fill the input.
751        abi_args = [] if abi_args is None else abi_args
752        for key, value in abi_args:
753            self._check_varname(key)
754
755        abi_kwargs = {} if abi_kwargs is None else abi_kwargs
756        for key in abi_kwargs:
757            self._check_varname(key)
758
759        args = list(abi_args)[:]
760        args.extend(list(abi_kwargs.items()))
761
762        self._vars = OrderedDict(args)
763        self.set_structure(structure)
764
765        if pseudo_dir is not None:
766            pseudo_dir = os.path.abspath(pseudo_dir)
767            if not os.path.exists(pseudo_dir):
768                raise self.Error("Directory %s does not exist" % pseudo_dir)
769            pseudos = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)]
770
771        try:
772            self._pseudos = PseudoTable.as_table(pseudos).get_pseudos_for_structure(self.structure)
773        except ValueError as exc:
774            raise self.Error(str(exc))
775
776        if comment is not None:
777            self.set_comment(comment)
778
779    @pmg_serialize
780    def as_dict(self):
781        """
782        JSON interface used in pymatgen for easier serialization.
783        """
784        # Use a list of (key, value) to serialize the OrderedDict
785        abi_args = []
786        for key, value in self.items():
787            if isinstance(value, np.ndarray):
788                value = value.tolist()
789            abi_args.append((key, value))
790
791        return dict(
792            structure=self.structure.as_dict(),
793            pseudos=[p.as_dict() for p in self.pseudos],
794            comment=self.comment,
795            abi_args=abi_args,
796        )
797
798    @property
799    def vars(self):
800        """Dictionary with variables."""
801        return self._vars
802
803    @classmethod
804    def from_dict(cls, d):
805        """
806        JSON interface used in pymatgen for easier serialization.
807        """
808        pseudos = [Pseudo.from_file(p["filepath"]) for p in d["pseudos"]]
809        return cls(d["structure"], pseudos, comment=d["comment"], abi_args=d["abi_args"])
810
811    def add_abiobjects(self, *abi_objects):
812        """
813        This function receive a list of ``AbiVarable`` objects and add
814        the corresponding variables to the input.
815        """
816        d = {}
817        for obj in abi_objects:
818            if not hasattr(obj, "to_abivars"):
819                raise TypeError("type %s: %s does not have `to_abivars` method" % (type(obj), repr(obj)))
820            d.update(self.set_vars(obj.to_abivars()))
821        return d
822
823    def __setitem__(self, key, value):
824        if key in _TOLVARS_SCF and hasattr(self, "_vars") and any(t in self._vars and t != key for t in _TOLVARS_SCF):
825            logger.info(
826                "Replacing previously set tolerance variable: {0}.".format(self.remove_vars(_TOLVARS_SCF, strict=False))
827            )
828
829        return super().__setitem__(key, value)
830
831    def _check_varname(self, key):
832        if key in GEOVARS:
833            raise self.Error(
834                "You cannot set the value of a variable associated to the structure.\n"
835                "Use Structure objects to prepare the input file."
836            )
837
838    def to_string(self, post=None, with_structure=True, with_pseudos=True, exclude=None):
839        r"""
840        String representation.
841
842        Args:
843            post: String that will be appended to the name of the variables
844                Note that post is usually autodetected when we have multiple datatasets
845                It is mainly used when we have an input file with a single dataset
846                so that we can prevent the code from adding "1" to the name of the variables
847                (In this case, indeed, Abinit complains if ndtset=1 is not specified
848                and we don't want ndtset=1 simply because the code will start to add
849                _DS1_ to all the input and output files.
850            with_structure: False if section with structure variables should not be printed.
851            with_pseudos: False if JSON section with pseudo data should not be added.
852            exclude: List of variable names that should be ignored.
853        """
854        lines = []
855        app = lines.append
856
857        if self.comment:
858            app("# " + self.comment.replace("\n", "\n#"))
859
860        post = post if post is not None else ""
861        exclude = set(exclude) if exclude is not None else set()
862
863        # Default is no sorting else alphabetical order.
864        keys = sorted([k for k, v in self.items() if k not in exclude and v is not None])
865
866        # Extract the items from the dict and add the geo variables at the end
867        items = [(k, self[k]) for k in keys]
868        if with_structure:
869            items.extend(list(aobj.structure_to_abivars(self.structure).items()))
870
871        for name, value in items:
872            # Build variable, convert to string and append it
873            vname = name + post
874            app(str(InputVariable(vname, value)))
875
876        s = "\n".join(lines)
877        if not with_pseudos:
878            return s
879
880        # Add JSON section with pseudo potentials.
881        ppinfo = ["\n\n\n#<JSON>"]
882        d = {"pseudos": [p.as_dict() for p in self.pseudos]}
883        ppinfo.extend(json.dumps(d, indent=4).splitlines())
884        ppinfo.append("</JSON>")
885
886        s += "\n#".join(ppinfo)
887        return s
888
889    @property
890    def comment(self):
891        """Optional string with comment. None if comment is not set."""
892        try:
893            return self._comment
894        except AttributeError:
895            return None
896
897    def set_comment(self, comment):
898        """Set a comment to be included at the top of the file."""
899        self._comment = comment
900
901    @property
902    def structure(self):
903        """The |Structure| object associated to this input."""
904        return self._structure
905
906    def set_structure(self, structure):
907        """Set structure."""
908        self._structure = as_structure(structure)
909
910        # Check volume
911        m = self.structure.lattice.matrix
912        if np.dot(np.cross(m[0], m[1]), m[2]) <= 0:
913            raise self.Error("The triple product of the lattice vector is negative. Use structure.abi_sanitize.")
914
915        return self._structure
916
917    # Helper functions to facilitate the specification of several variables.
918    def set_kmesh(self, ngkpt, shiftk, kptopt=1):
919        """
920        Set the variables for the sampling of the BZ.
921
922        Args:
923            ngkpt: Monkhorst-Pack divisions
924            shiftk: List of shifts.
925            kptopt: Option for the generation of the mesh.
926        """
927        shiftk = np.reshape(shiftk, (-1, 3))
928        return self.set_vars(ngkpt=ngkpt, kptopt=kptopt, nshiftk=len(shiftk), shiftk=shiftk)
929
930    def set_gamma_sampling(self):
931        """Gamma-only sampling of the BZ."""
932        return self.set_kmesh(ngkpt=(1, 1, 1), shiftk=(0, 0, 0))
933
934    def set_kpath(self, ndivsm, kptbounds=None, iscf=-2):
935        """
936        Set the variables for the computation of the electronic band structure.
937
938        Args:
939            ndivsm: Number of divisions for the smallest segment.
940            kptbounds: k-points defining the path in k-space.
941                If None, we use the default high-symmetry k-path defined in the pymatgen database.
942        """
943
944        if kptbounds is None:
945            from pymatgen.symmetry.bandstructure import HighSymmKpath
946
947            hsym_kpath = HighSymmKpath(self.structure)
948
949            name2frac_coords = hsym_kpath.kpath["kpoints"]
950            kpath = hsym_kpath.kpath["path"]
951
952            frac_coords, names = [], []
953            for segment in kpath:
954                for name in segment:
955                    fc = name2frac_coords[name]
956                    frac_coords.append(fc)
957                    names.append(name)
958            kptbounds = np.array(frac_coords)
959
960        kptbounds = np.reshape(kptbounds, (-1, 3))
961        # self.pop_vars(["ngkpt", "shiftk"]) ??
962
963        return self.set_vars(kptbounds=kptbounds, kptopt=-(len(kptbounds) - 1), ndivsm=ndivsm, iscf=iscf)
964
965    def set_spin_mode(self, spin_mode):
966        """
967        Set the variables used to the treat the spin degree of freedom.
968        Return dictionary with the variables that have been removed.
969
970        Args:
971            spin_mode: :class:`SpinMode` object or string. Possible values for string are:
972
973            - polarized
974            - unpolarized
975            - afm (anti-ferromagnetic)
976            - spinor (non-collinear magnetism)
977            - spinor_nomag (non-collinear, no magnetism)
978        """
979        # Remove all variables used to treat spin
980        old_vars = self.pop_vars(["nsppol", "nspden", "nspinor"])
981        self.add_abiobjects(aobj.SpinMode.as_spinmode(spin_mode))
982        return old_vars
983
984    @property
985    def pseudos(self):
986        """List of |Pseudo| objects."""
987        return self._pseudos
988
989    @property
990    def ispaw(self):
991        """True if PAW calculation."""
992        return all(p.ispaw for p in self.pseudos)
993
994    @property
995    def isnc(self):
996        """True if norm-conserving calculation."""
997        return all(p.isnc for p in self.pseudos)
998
999    def new_with_vars(self, *args, **kwargs):
1000        """
1001        Return a new input with the given variables.
1002
1003        Example:
1004            new = input.new_with_vars(ecut=20)
1005        """
1006        # Avoid modifications in self.
1007        new = self.deepcopy()
1008        new.set_vars(*args, **kwargs)
1009        return new
1010
1011    def pop_tolerances(self):
1012        """
1013        Remove all the tolerance variables present in self.
1014        Return dictionary with the variables that have been removed.
1015        """
1016        return self.remove_vars(_TOLVARS, strict=False)
1017
1018    def pop_irdvars(self):
1019        """
1020        Remove all the `ird*` variables present in self.
1021        Return dictionary with the variables that have been removed.
1022        """
1023        return self.remove_vars(_IRDVARS, strict=False)
1024
1025
class BasicMultiDataset:
    """
    This object is essentially a list of BasicAbinitInput objects
    that provides an easy-to-use interface to apply global changes to
    the inputs stored in the objects.

    Let's assume for example that multi contains two ``BasicAbinitInput`` objects and we
    want to set `ecut` to 1 in both dictionaries. The direct approach would be:

        for inp in multi:
            inp.set_vars(ecut=1)

    or alternatively:

        for i in range(multi.ndtset):
            multi[i].set_vars(ecut=1)

    BasicMultiDataset provides its own implementation of __getattr__ so that one can simply use:

        multi.set_vars(ecut=1)

        multi.get("ecut") returns a list of values. It's equivalent to:

            [inp["ecut"] for inp in multi]

        Note that if "ecut" is not present in one of the input of multi, the corresponding entry is set to None.
        A default value can be specified with:

            multi.get("paral_kgb", 0)

    .. warning::

        BasicMultiDataset does not support calculations done with different sets of pseudopotentials.
        The inputs can have different crystalline structures (as long as the atom types are equal)
        but each input in BasicMultiDataset must have the same set of pseudopotentials.
    """

    Error = BasicAbinitInputError

    @staticmethod
    def _check_pseudos(ref_pseudos, pseudos, method_name):
        """Raise ValueError if ``pseudos`` differs from ``ref_pseudos`` in length or in any element."""
        ref_list, other_list = list(ref_pseudos), list(pseudos)
        # Compare lengths explicitly: zip alone would silently ignore extra pseudos.
        if len(ref_list) != len(other_list) or any(p1 != p2 for p1, p2 in zip(ref_list, other_list)):
            raise ValueError("Pseudos must be consistent when %s is invoked." % method_name)

    @classmethod
    def from_inputs(cls, inputs):
        """
        Build object from a list of BasicAbinitInput objects.

        The structures are taken from each input, the pseudos from the first one,
        and the variables are copied with ``set_vars``.

        Raises:
            ValueError: if ``inputs`` is empty or the inputs do not share the same pseudos.
        """
        if len(inputs) == 0:
            raise ValueError("Cannot build a BasicMultiDataset from an empty list of inputs.")

        for inp in inputs:
            cls._check_pseudos(inputs[0].pseudos, inp.pseudos, "from_inputs")

        # Build BasicMultiDataset from input structures and pseudos and add inputs.
        multi = cls(
            structure=[inp.structure for inp in inputs],
            pseudos=inputs[0].pseudos,
            ndtset=len(inputs),
        )

        # Add variables
        for inp, new_inp in zip(inputs, multi):
            new_inp.set_vars(**inp)

        return multi

    @classmethod
    def replicate_input(cls, input, ndtset):
        """Construct a multidataset with ndtset copies of the variables of the BasicAbinitInput input."""
        multi = cls(input.structure, input.pseudos, ndtset=ndtset)

        for inp in multi:
            inp.set_vars(**input)

        return multi

    def __init__(self, structure, pseudos, pseudo_dir="", ndtset=1):
        """
        Args:
            structure: file with the structure, |Structure| object or dictionary with ABINIT geo variable
                Accepts also list of objects that can be converted to Structure object.
                In this case, however, ndtset must be equal to the length of the list.
            pseudos: String or list of string with the name of the pseudopotential files.
            pseudo_dir: Name of the directory where the pseudopotential files are located.
            ndtset: Number of datasets.

        Raises:
            self.Error: if some pseudopotential files cannot be found.
            ValueError: if ndtset <= 0 or if the number of structures differs from ndtset.
        """
        # Setup of the pseudopotential files.
        if isinstance(pseudos, Pseudo):
            pseudos = [pseudos]

        elif isinstance(pseudos, PseudoTable):
            pass  # Already a PseudoTable, use it as is.

        elif all(isinstance(p, Pseudo) for p in pseudos):
            pseudos = PseudoTable(pseudos)

        else:
            # String(s)
            pseudo_dir = os.path.abspath(pseudo_dir)
            pseudo_paths = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)]

            missing = [p for p in pseudo_paths if not os.path.exists(p)]
            if missing:
                raise self.Error("Cannot find the following pseudopotential files:\n%s" % str(missing))

            pseudos = PseudoTable(pseudo_paths)

        # Build the list of BasicAbinitInput objects.
        if ndtset <= 0:
            raise ValueError("ndtset %d cannot be <=0" % ndtset)

        if not isinstance(structure, (list, tuple)):
            self._inputs = [BasicAbinitInput(structure=structure, pseudos=pseudos) for _ in range(ndtset)]
        else:
            if len(structure) != ndtset:
                raise ValueError(
                    "The number of structures (%d) must be equal to ndtset (%d)" % (len(structure), ndtset)
                )
            self._inputs = [BasicAbinitInput(structure=s, pseudos=pseudos) for s in structure]

    @property
    def ndtset(self):
        """Number of inputs in self."""
        return len(self)

    @property
    def pseudos(self):
        """Pseudopotential objects."""
        return self[0].pseudos

    @property
    def ispaw(self):
        """True if PAW calculation."""
        return all(p.ispaw for p in self.pseudos)

    @property
    def isnc(self):
        """True if norm-conserving calculation."""
        return all(p.isnc for p in self.pseudos)

    def __len__(self):
        return len(self._inputs)

    def __getitem__(self, key):
        return self._inputs[key]

    def __iter__(self):
        return self._inputs.__iter__()

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails. object.__getattribute__ is used so
        # that a missing _inputs (e.g. on an instance whose __init__ has not run) raises
        # AttributeError instead of recursing into __getattr__.
        _inputs = object.__getattribute__(self, "_inputs")
        m = getattr(_inputs[0], name)
        if m is None:
            raise AttributeError(
                "Cannot find attribute %s. Tried in %s and then in BasicAbinitInput object"
                % (name, self.__class__.__name__)
            )
        isattr = not callable(m)

        def on_all(*args, **kwargs):
            """Evaluate the attribute (calling it if callable) on every input and collect the results."""
            results = []
            for obj in self._inputs:
                a = getattr(obj, name)
                if callable(a):
                    results.append(a(*args, **kwargs))
                else:
                    results.append(a)

            return results

        # Plain attributes are resolved immediately; methods return a broadcasting callable.
        if isattr:
            on_all = on_all()

        return on_all

    def __add__(self, other):
        """self + other"""
        if isinstance(other, BasicAbinitInput):
            new_mds = BasicMultiDataset.from_inputs(self)
            new_mds.append(other)
            return new_mds
        if isinstance(other, BasicMultiDataset):
            new_mds = BasicMultiDataset.from_inputs(self)
            new_mds.extend(other)
            return new_mds

        raise NotImplementedError("Operation not supported")

    def __radd__(self, other):
        """other + self"""
        if isinstance(other, BasicAbinitInput):
            new_mds = BasicMultiDataset.from_inputs([other])
            new_mds.extend(self)
        elif isinstance(other, BasicMultiDataset):
            new_mds = BasicMultiDataset.from_inputs(other)
            new_mds.extend(self)
        else:
            raise NotImplementedError("Operation not supported")

        return new_mds

    def append(self, abinit_input):
        """
        Add a |BasicAbinitInput| to the list.

        Raises:
            TypeError: if abinit_input is not a BasicAbinitInput.
            ValueError: if its pseudos differ from the ones of the first dataset.
        """
        if not isinstance(abinit_input, BasicAbinitInput):
            raise TypeError("Expecting BasicAbinitInput, got %s" % type(abinit_input))
        self._check_pseudos(self[0].pseudos, abinit_input.pseudos, "append")
        self._inputs.append(abinit_input)

    def extend(self, abinit_inputs):
        """
        Extends self with a list of |BasicAbinitInput| objects.

        Raises:
            TypeError: if one of the items is not a BasicAbinitInput.
            ValueError: if the pseudos of one of the items differ from the ones of the first dataset.
        """
        # Materialize first: the items are traversed twice (validation + extend), which would
        # exhaust a generator and would never terminate for multi.extend(multi).
        abinit_inputs = list(abinit_inputs)
        for inp in abinit_inputs:
            if not isinstance(inp, BasicAbinitInput):
                raise TypeError("Expecting BasicAbinitInput, got %s" % type(inp))
        for inp in abinit_inputs:
            self._check_pseudos(self[0].pseudos, inp.pseudos, "extend")
        self._inputs.extend(abinit_inputs)

    def addnew_from(self, dtindex):
        """Add a new entry in the multidataset by copying the input with index ``dtindex``."""
        self.append(self[dtindex].deepcopy())

    def split_datasets(self):
        """Return list of |BasicAbinitInput| objects."""
        return self._inputs

    def deepcopy(self):
        """Deep copy of the BasicMultiDataset."""
        return copy.deepcopy(self)

    @property
    def has_same_structures(self):
        """True if all inputs in BasicMultiDataset have the same structure."""
        return all(self[0].structure == inp.structure for inp in self)

    def __str__(self):
        return self.to_string()

    def to_string(self, with_pseudos=True):
        """
        String representation i.e. the input file read by Abinit.

        Args:
            with_pseudos: False if JSON section with pseudo data should not be added.
        """
        if self.ndtset > 1:
            # Multi dataset mode.
            lines = ["ndtset %d" % self.ndtset]

            def has_same_variable(kref, vref, other_inp):
                """True if variable kref is present in other_inp with the same value."""
                if kref not in other_inp:
                    return False
                otherv = other_inp[kref]
                return np.array_equal(vref, otherv)

            # Don't repeat variable that are common to the different datasets.
            # Put them in the `Global Variables` section and exclude these variables in inp.to_string
            global_vars = {
                k0 for k0, v0 in self[0].items() if all(has_same_variable(k0, v0, inp) for inp in self[1:])
            }

            w = 92
            if global_vars:
                lines.append(w * "#")
                lines.append("### Global Variables.")
                lines.append(w * "#")
                # Sort so that the output does not depend on set iteration order.
                for key in sorted(global_vars):
                    lines.append(str(InputVariable(key, self[0][key])))

            has_same_structures = self.has_same_structures
            if has_same_structures:
                # Write structure here and disable structure output in input.to_string
                lines.append(w * "#")
                lines.append("#" + ("STRUCTURE").center(w - 1))
                lines.append(w * "#")
                for key, value in aobj.structure_to_abivars(self[0].structure).items():
                    lines.append(str(InputVariable(key, value)))

            for i, inp in enumerate(self):
                header = "### DATASET %d ###" % (i + 1)
                is_last = i == self.ndtset - 1
                s = inp.to_string(
                    post=str(i + 1),
                    with_pseudos=is_last and with_pseudos,
                    with_structure=not has_same_structures,
                    exclude=global_vars,
                )
                if s:
                    header = len(header) * "#" + "\n" + header + "\n" + len(header) * "#" + "\n"
                    s = "\n" + header + s + "\n"

                lines.append(s)

            return "\n".join(lines)

        # single datasets ==> don't append the dataset index to the variables.
        # this trick is needed because Abinit complains if ndtset is not specified
        # and we have variables that end with the dataset index e.g. acell1
        # We don't want to specify ndtset here since abinit will start to add DS# to
        # the input and output files thus complicating the algorithms we have to use to locate the files.
        return self[0].to_string(with_pseudos=with_pseudos)

    def write(self, filepath="run.abi"):
        """
        Write ``ndtset`` input files to disk, one per dataset.

        The name of each file is built by inserting ``DS<index>`` (0-based) before the
        extension of filepath, e.g. run.abi -> runDS0.abi, runDS1.abi, ...
        """
        root, ext = os.path.splitext(filepath)
        for i, inp in enumerate(self):
            p = root + "DS%d" % i + ext
            inp.write(filepath=p)
1333