1"""This module contains lookup table with the name of the ABINIT variables."""
2import os
3import warnings
4import numpy as np
5
6from pprint import pformat
7from monty.string import is_string, boxed
8from monty.functools import lazy_property
9from monty.termcolor import cprint
10from pymatgen.core.units import bohr_to_ang
11from abipy.core.structure import Structure, dataframes_from_structures
12from abipy.core.mixins import Has_Structure, TextFile, NotebookWriter
13from abipy.abio.abivar_database.variables import get_codevars
14
15__all__ = [
16    "is_abivar",
17    "is_abiunit",
18    "AbinitInputFile",
19    "AbinitInputParser",
20]
21
22
def is_anaddb_var(varname):
    """Return True if ``varname`` is a valid Anaddb variable."""
    anaddb_vars = get_codevars()["anaddb"]
    return varname in anaddb_vars
26
27
def is_abivar(varname):
    """Return True if ``varname`` is an ABINIT variable."""
    # Names accepted by the parser although they are not (yet) in the database.
    # FIXME: These variables should be added to the database.
    if varname in ("include", "xyzfile"):
        return True
    return varname in get_codevars()["abinit"]
34
35
# TODO: Move to new directory
# Operators supported inside Abinit input files.
ABI_OPERATORS = {"sqrt"}

# Unit names understood by the Abinit parser (stored lowercase for
# case-insensitive lookup in is_abiunit).
ABI_UNIT_NAMES = {u.lower() for u in [
    "au", "nm",
    "Angstr", "Angstrom", "Angstroms", "Bohr", "Bohrs",
    "eV", "Ha", "Hartree", "Hartrees", "K", "Ry", "Rydberg", "Rydbergs",
    "T", "Tesla",
]}
46
47
def is_abiunit(s):
    """
    Return True if string ``s`` is one of the units supported by the ABINIT parser.
    Non-string arguments always give False.
    """
    return is_string(s) and s.lower() in ABI_UNIT_NAMES
54
55
def expand_star_syntax(s):
    """
    Expand Abinit's star syntax (`n*value` --> value repeated n times) and return the new string.
    Note that Abinit does not accept white spaces inside the expression:
    e.g. `typat 2 * 1` is not valid.

    >>> assert expand_star_syntax("3*2") == '2 2 2'
    >>> assert expand_star_syntax("2 *1") == '1 1'
    >>> assert expand_star_syntax("1 2*2") == '1 2 2'
    >>> assert expand_star_syntax("*2") == '*2'
    """
    s = s.strip()
    if "*" not in s:
        return s
    # A variable name ending with a star (e.g. `pawecutdg*`) is not star syntax.
    if s[0].isalpha() and s[-1] == "*":
        return s

    # Isolate every `*` so that split() yields clean tokens.
    tokens = s.replace("*", " * ").split()

    # Preserve the `*value` form (value broadcast to the whole array).
    if len(tokens) == 2 and tokens[0] == "*":
        assert tokens[1] != "*"
        return "".join(tokens)

    out, i, ntoks = [], 0, len(tokens)
    while i < ntoks:
        tok = tokens[i]
        if tok == "*":
            # Replace `n * value` with `value` repeated n times.
            repeat = int(out.pop())
            out.extend(repeat * [tokens[i + 1]])
            i += 2
        else:
            out.append(tok)
            i += 1

    return " ".join(out)
95
96
def str2array_bohr(obj):
    """
    Convert ``obj`` into a numpy array in Bohr units.
    ``obj`` may be an array-like object (returned as-is via np.asarray)
    or a string, possibly terminated by a unit name.

    Raises:
        ValueError: if the trailing unit is not a supported length unit.
    """
    if not is_string(obj):
        return np.asarray(obj)

    # Expand e.g. `3 * 1.0` and normalize Fortran exponents ("d" --> "e")
    # since numpy does not understand "0.00d0".
    s = expand_star_syntax(obj).lower().replace("d", "e")

    tokens = s.split()
    if tokens[-1].isalpha():
        # Last token is a unit name: convert the numerical part accordingly.
        unit = tokens[-1]
        values = " ".join(tokens[:-1])
        if unit in ("angstr", "angstrom", "angstroms"):
            return np.fromstring(values, sep=" ") / bohr_to_ang
        if unit in ("bohr", "bohrs", "au"):
            return np.fromstring(values, sep=" ")
        raise ValueError("Don't know how to handle unit: %s" % str(unit))

    # No unit given: values are assumed to be in Bohr already.
    return np.fromstring(s, sep=" ")
118
119
def str2array(obj, dtype=float):
    """
    Convert ``obj`` (string or array-like) into a numpy array with the given ``dtype``.

    Raises:
        ValueError: if the string starts with `*` (the broadcast form must be
            handled by the caller since the array length is unknown here).
    """
    if not is_string(obj):
        return np.asarray(obj)
    if obj.startswith("*"):
        raise ValueError("This case should be treated by the caller: %s" % str(obj))
    # Expand star syntax and normalize Fortran exponents ("d" --> "e")
    # since numpy does not understand "0.00d0".
    expanded = expand_star_syntax(obj).lower().replace("d", "e")
    return np.fromstring(expanded, sep=" ", dtype=dtype)
128
129
class Dataset(dict, Has_Structure):
    """
    Dict-like object storing the variables of a single Abinit dataset
    (mapping: variable name --> string with the value read from the input file).
    """

    @lazy_property
    def structure(self):
        """
        The initial structure associated to the dataset, built from the geometry
        variables (acell, rprim/angdeg, typat, znucl, xred/xcart/xangst) or read
        from an external file when the `structure` variable is present.
        """

        # First of all check whether the structure is defined through external file.
        # The value has the form `filetype:path` (double quotes, if any, are removed).
        if "structure" in self:
            s = self["structure"].replace('"', "")
            filetype, path = s.split(":")
            from abipy import abilab
            with abilab.abiopen(path) as abifile:
                return abifile.structure

        # Get lattice.
        kwargs = {}
        if "angdeg" in self:
            # angdeg and rprim are mutually exclusive ways of specifying the lattice.
            if "rprim" in self:
                raise ValueError("rprim and angdeg cannot be used together!")
            angdeg = str2array(self["angdeg"])
            angdeg.shape = (3)
            kwargs["angdeg"] = angdeg
        else:
            # Handle structure specified with rprim (default: identity matrix).
            kwargs["rprim"] = str2array_bohr(self.get("rprim", "1.0 0 0 0 1 0 0 0 1"))

        # Default value for acell.
        acell = str2array_bohr(self.get("acell", "1.0 1.0 1.0"))

        # Get important dimensions.
        ntypat = int(self.get("ntypat", 1))
        natom = int(self.get("natom", 1))

        # znucl(npsp)
        # Handle the broadcast form `znucl *14` i.e. same Z for all pseudos.
        znucl = self["znucl"]
        if znucl.startswith("*"):
            i = znucl.find("*")
            znucl_size = natom if "npsp" not in self else int(self["npsp"])
            znucl = znucl_size * [float(znucl[i+1:])]
        else:
            znucl = str2array(self["znucl"])

        # v67mbpt/Input/t12.in
        # Handle the broadcast form `typat *1` i.e. same type for all atoms.
        typat = self["typat"]
        if typat.startswith("*"):
            i = typat.find("*")
            typat = np.array(natom * [int(typat[i+1:])], dtype=int)
        else:
            typat = str2array(self["typat"], dtype=int)

        # Extract atomic positions.
        # Select first natom entries (needed if multidatasets with different natom)
        #    # v3/Input/t05.in
        typat = typat[:natom]
        for k in ("xred", "xcart", "xangst"):
            # xcart is given in Bohr --> use the Bohr-aware converter.
            toarray = str2array_bohr if k == "xcart" else str2array
            if k in self:
                arr = np.reshape(toarray(self[k]), (-1, 3))
                kwargs[k] = arr[:natom]
                break
        else:
            raise ValueError("xred|xcart|xangst must be given in input")

        try:
            return Structure.from_abivars(acell=acell, znucl=znucl, typat=typat, **kwargs)
        except Exception as exc:
            # Print the parsed values to facilitate debugging before re-raising.
            print("Wrong inputs passed to Structure.from_abivars:")
            print("acell:", acell, "znucl:", znucl, "typat:", typat, "kwargs:", kwargs, sep="\n")
            raise exc

    def get_vars(self):
        """
        Return dictionary with variables. The variables describing the crystalline structure
        are removed from the output dictionary.
        """
        geovars = {"acell", "angdeg", "rprim", "ntypat", "natom", "znucl", "typat", "xred", "xcart", "xangst"}
        return {k: self[k] for k in self if k not in geovars}

    def __str__(self):
        """String representation (same as to_string with default arguments)."""
        return self.to_string()

    def to_string(self, post=None, mode="text", verbose=0):
        """
        String representation.

        Args:
            post: String that will be appended to the name of the variables
                (e.g. the dataset index).
            mode: Either `text` or `html` if HTML output with links is wanted.
            verbose: Verbosity level.
        """
        post = post if post is not None else ""
        if mode == "html":
            # The variable database is needed to generate hyperlinks to the Abinit docs.
            from abipy.abio.abivars_db import get_abinit_variables
            var_database = get_abinit_variables()

        lines = []
        app = lines.append
        # Emit `name value` pairs sorted by variable name.
        for k in sorted(list(self.keys())):
            vname = k + post
            if mode == "html": vname = var_database[k].html_link(label=vname)
            app("%s %s" % (vname, str(self[k])))

        return "\n".join(lines) if mode == "text" else "\n".join(lines).replace("\n", "<br>")

    def _repr_html_(self):
        """Integration with jupyter_ notebooks."""
        return self.to_string(mode="html")
239
240
class AbinitInputFile(TextFile, Has_Structure, NotebookWriter):
    """
    This object parses the Abinit input file, stores the variables in
    dict-like objects (Datasets) and build `Structure` objects from
    the input variables. Mainly used for inspecting the structure
    declared in the Abinit input file.
    """

    @classmethod
    def from_string(cls, string):
        """Build the object from a string containing the Abinit input."""
        import tempfile
        fd, filename = tempfile.mkstemp(suffix=".abi", text=True)
        # mkstemp returns an open OS-level file descriptor: wrap it with
        # fdopen so it is properly closed (the previous code leaked the fd).
        with os.fdopen(fd, "wt") as fh:
            fh.write(string)
        return cls(filename)

    def __init__(self, filepath):
        """
        Args:
            filepath: Path to the Abinit input file.
        """
        super().__init__(filepath)

        with open(filepath, "rt") as fh:
            self.string = fh.read()

        # Parse the input string and build the list of Dataset objects.
        self.datasets = AbinitInputParser().parse(self.string)
        self.ndtset = len(self.datasets)

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """
        String representation: the raw input followed by a summary of the
        structure(s) defined in the file.

        Args:
            verbose: Verbosity level.
        """
        lines = []
        app = lines.append
        header = 10 * "=" + " Input File " + 10 * "="
        app(header)
        app(self.string)
        app(len(header) * "=" + "\n")

        # Print info on structure(s).
        if self.structure is not None:
            app(self.structure.spget_summary())
        else:
            # Datasets have different structures: print one summary per dataset
            # plus a tabular comparison of lattices and atomic positions.
            structures = [dt.structure for dt in self.datasets]
            app("Input file contains %d structures:" % len(structures))
            for i, structure in enumerate(structures):
                app(boxed("Dataset: %d" % (i+1)))
                app(structure.spget_summary())
                app("")

            dfs = dataframes_from_structures(structures, index=[i+1 for i in range(self.ndtset)])
            app(boxed("Tabular view (each row corresponds to a dataset structure)"))
            app("")
            app("Lattice parameters:")
            app(str(dfs.lattice))
            app("")
            app("Atomic positions:")
            app(str(dfs.coords))

        return "\n".join(lines)

    @lazy_property
    def has_multi_structures(self):
        """True if input defines multiple structures."""
        return self.structure is None

    def _repr_html_(self):
        """Integration with jupyter notebooks."""
        from abipy.abio.abivars_db import repr_html_from_abinit_string
        return repr_html_from_abinit_string(self.string)

    def close(self):
        """NOP, required by ABC."""

    @lazy_property
    def structure(self):
        """
        The structure defined in the input file.

        If the input file contains multiple datasets **AND** the datasets
        have different structures, this property returns None.
        In this case, one has to access the structure of the individual datasets.
        For example:

            input.datasets[0].structure

        gives the structure of the first dataset.
        """
        for dt in self.datasets[1:]:
            if dt.structure != self.datasets[0].structure:
                warnings.warn("Datasets have different structures. Returning None. Use input.datasets[i].structure")
                return None

        return self.datasets[0].structure

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        if not self.has_multi_structures:
            yield self.structure.plot(show=False)
            yield self.structure.plot_bz(show=False)
        else:
            for dt in self.datasets:
                yield dt.structure.plot(show=False)
                yield dt.structure.plot_bz(show=False)

    def write_notebook(self, nbpath=None):
        """
        Write an ipython notebook to nbpath. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("abinp = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(abinp)"),
        ])

        if self.has_multi_structures:
            # BUGFIX: the generated cell must reference `abinp` (defined in the
            # first cell); the previous version used an undefined `inp` variable.
            nb.cells.extend([
                nbv.new_code_cell("""\
for dataset in abinp.datasets:
    print(dataset.structure)"""),
            ])

        if self.ndtset > 1:
            nb.cells.extend([
                nbv.new_code_cell("""\
for dataset in abinp.datasets:
    print(dataset)"""),
            ])

        return self._write_nb_nbpath(nb, nbpath)
375
376
class AbinitInputParser(object):
    """
    Parses a string with an Abinit input and produces a list of Dataset objects.
    """
    # Set to a non-zero value to print debug messages while parsing.
    verbose = 0

    def parse(self, s):
        """
        This function receives a string `s` with the Abinit input and return
        a list of :class:`Dataset` objects.

        Raises:
            RuntimeError: if a variable has an empty value.
            ValueError: if unknown variables are found.
            NotImplementedError: if udtset or the spgroup builder is used.
        """
        # TODO: Parse PSEUDO section if present!
        # Remove comments from lines (both `#` and `!` start a comment).
        lines = []
        for line in s.splitlines():
            # BUGFIX: the result of strip() was previously discarded.
            line = line.strip()
            i = line.find("#")
            if i != -1: line = line[:i]
            i = line.find("!")
            if i != -1: line = line[:i]
            if line: lines.append(line)

        # 1) Build string of the form "var1 value1 var2 value2"
        # 2) split string in tokens.
        # 3) Evaluate star syntax i.e. "3*2" ==> '2 2 2'
        # 4) Evaluate operators e.g. sqrt(0.75)
        tokens = " ".join(lines).split()
        # Step 3 is needed because we are gonna use python to evaluate the operators and
        # in abinit `2*sqrt(0.75)` means `sqrt(0.75) sqrt(0.75)` and not math multiplication!
        if self.verbose: print("tokens", tokens)
        new_tokens = []
        for t in tokens:
            new_tokens.extend(expand_star_syntax(t).split())
        tokens = new_tokens
        if self.verbose: print("new_tokens", new_tokens)

        tokens = self.eval_abinit_operators(tokens)

        # Find the positions of the tokens that start a new variable.
        varpos = []
        for pos, tok in enumerate(tokens):
            if tok[0].isalpha():
                # Either new variable, string defining the unit or operator e.g. sqrt
                if is_abiunit(tok) or tok in ABI_OPERATORS or "?" in tok:
                    continue
                # Note: a token ending with digits (e.g. `ecut2`) carries a
                # dataset index that is extracted later by varname_dtindex.
                varpos.append(pos)

        # Sentinel so that the last variable consumes all the remaining tokens.
        varpos.append(len(tokens))

        # Build dict {varname --> value_string}
        dvars = {}
        for i, pos in enumerate(varpos[:-1]):
            dvars[tokens[pos]] = " ".join(tokens[pos+1: varpos[i+1]])

        # Sanity check: every variable must have a non-empty value.
        err_lines = []
        for k, v in dvars.items():
            if not v:
                err_lines.append("key `%s` was not parsed correctly (empty value)" % k)
        if err_lines:
            raise RuntimeError("\n".join(err_lines))

        # Get value of ndtset.
        ndtset = int(dvars.pop("ndtset", 1))
        udtset = dvars.pop("udtset", None)
        jdtset = dvars.pop("jdtset", None)
        if udtset is not None:
            raise NotImplementedError("udtset is not supported")

        # Build list of datasets.
        datasets = [Dataset() for i in range(ndtset)]

        # Treat all variables without a dataset index: they are common to all datasets.
        kv_list = list(dvars.items())
        for k, v in kv_list:
            if k[-1].isdigit() or any(c in k for c in ("?", ":", "+", "*")): continue
            for d in datasets: d[k] = v
            dvars.pop(k)

        # Treat all variables with a dataset index except those with "?", ":", "+"
        kv_list = list(dvars.items())
        for k, v in kv_list:
            if any(c in k for c in ("?", ":", "+", "*")): continue
            varname, idt = self.varname_dtindex(k)
            dvars.pop(k)
            if idt > ndtset:
                # Variable refers to a dataset that is not used: ignore it.
                if self.verbose: print("Ignoring key: %s because ndtset: %d" % (k, ndtset))
                continue
            datasets[idt-1][varname] = v

        # Now treat series e.g. ecut: 10 ecut+ 5 (NB: ? is not treated here)
        kv_list = list(dvars.items())
        for k, v in kv_list:
            if "?" in k: continue
            if ":" not in k: continue
            # TODO units
            vname = k[:-1]
            start = str2array(dvars.pop(k))

            # Handle arithmetic series (ecut+) or geometric series (ecut*).
            incr = dvars.pop(vname + "+", None)
            if incr is not None:
                incr = str2array(incr)
                for dt in datasets:
                    dt[vname] = start.copy()
                    start += incr
            else:
                mult = str2array(dvars.pop(vname + "*"))
                for dt in datasets:
                    dt[vname] = start.copy()
                    start *= mult

        # Consistency check
        # 1) dvars should be empty
        if dvars:
            # BUGFIX: `indent=4` belongs to pformat; passing it to ValueError
            # raised a TypeError instead of the intended ValueError.
            raise ValueError("Don't know how to handle variables in:\n%s" % pformat(dvars, indent=4))

        # 2) Keys in datasets should be valid Abinit input variables.
        wrong = []
        for i, dt in enumerate(datasets):
            wlist = [k for k in dt if not is_abivar(k)]
            if wlist:
                wrong.extend(("dataset %d" % i, wlist))
        if wrong:
            raise ValueError("Found variables that are not registered in the abipy database:\n%s" % pformat(wrong, indent=4))

        # 3) We don't support spg builder: dataset.structure will fail or, even worse,
        #    spglib will segfault so it's better to raise here!
        for dt in datasets:
            if "spgroup" in dt or "nobj" in dt:
                raise NotImplementedError(
                    "Abinit spgroup builder is not supported. Structure must be given explicitly!")

        if jdtset is not None:
            # Return the datasets selected by jdtset.
            datasets = [datasets[i-1] for i in np.fromstring(jdtset, sep=" ", dtype=int)]

        return datasets

    @staticmethod
    def eval_abinit_operators(tokens):
        """
        Receive a list of strings, find the occurrences of operators supported
        in the input file (e.g. sqrt), evaluate the expression and return new list of strings.

        .. note:

            This function is not recursive hence expr like sqrt(1/2) are not supported
        """
        import math # noqa: F401
        import re
        re_sqrt = re.compile(r"[+|-]?sqrt\((.+)\)")

        values = []
        for tok in tokens:
            # SECURITY: eval is applied to tokens read from the input file.
            # This assumes the input file comes from a trusted source.
            if re_sqrt.match(tok):
                tok = str(eval(tok.replace("sqrt", "math.sqrt")))
            if "/" in tok:
                tok = str(eval(tok))
            values.append(tok)
        return values

    @staticmethod
    def varname_dtindex(tok):
        """
        Split a variable name with a trailing dataset index into (name, index).

        >>> p = AbinitInputParser()
        >>> assert p.varname_dtindex("acell1") == ("acell", 1)
        >>> assert p.varname_dtindex("fa1k2") == ("fa1k", 2)
        """
        # Collect the trailing digits (scanning from the right).
        l = []
        for i, c in enumerate(tok[::-1]):
            if c.isalpha(): break
            l.append(c)
        else:
            raise ValueError("Cannot find dataset index in: %s" % tok)

        assert i > 0
        l.reverse()
        dtidx = int("".join(l))
        varname = tok[:len(tok)-i]

        return varname, dtidx
585
586
def validate_input_parser(abitests_dir=None, input_files=None):
    """
    Validate/test the AbinitInput parser on a collection of input files.

    Args:
        abitests_dir: Abinit tests directory (walked recursively).
        input_files: List of Abinit input files.

    Return: Exit code (number of files that failed to parse).
    """
    def is_abinit_input(path):
        """
        True if path is one of the input files used in the Abinit Test suite.
        """
        if path.endswith(".abi"): return True
        if not path.endswith(".in"): return False

        # Old-style `.in` files: look for the line declaring the executable.
        with open(path, "rt") as fh:
            return any("executable" in line and "abinit" in line for line in fh)

    # Files are collected in paths.
    paths = []

    if abitests_dir is not None:
        print("Analyzing directory %s for input files" % abitests_dir)
        for root, _dirnames, fnames in os.walk(abitests_dir):
            paths.extend(p for p in (os.path.join(root, f) for f in fnames)
                         if is_abinit_input(p))

    if input_files is not None:
        print("Analyzing files ", str(input_files))
        paths.extend(arg for arg in input_files if is_abinit_input(arg))

    nfiles = len(paths)
    if nfiles == 0:
        cprint("Empty list of input files.", "red")
        return 0

    print("Found %d Abinit input files" % len(paths))
    errpaths = []
    for path in paths:
        print(path + ": ", end="")
        try:
            # Parsing + string conversion exercises the whole machinery.
            str(AbinitInputFile.from_file(path))
            cprint("OK", "green")
        except NotImplementedError:
            cprint("NOTIMPLEMENTED", "magenta")
        except Exception:
            cprint("FAILED", "red")
            errpaths.append(path)
            import traceback
            print(traceback.format_exc())

    if errpaths:
        cprint("failed: %d/%d [%.1f%%]" % (len(errpaths), nfiles, 100 * len(errpaths)/nfiles), "red")
        for i, epath in enumerate(errpaths):
            cprint("[%d] %s" % (i, epath), "red")
    else:
        cprint("All input files successfully parsed!", "green")

    return len(errpaths)
675