1# coding: utf-8
2"""Wrappers for ABINIT main executables"""
3import os
4import numpy as np
5
6from monty.string import list_strings
7from io import StringIO
8
9import logging
10logger = logging.getLogger(__name__)
11
12__author__ = "Matteo Giantomassi"
13__copyright__ = "Copyright 2013, The Materials Project"
14__version__ = "0.1"
15__maintainer__ = "Matteo Giantomassi"
16__email__ = "gmatteo at gmail.com"
17__status__ = "Development"
18__date__ = "$Feb 21, 2013M$"
19
20__all__ = [
21    "Mrgscr",
22    "Mrgddb",
23    "Mrgdvdb",
24]
25
26
27class ExecError(Exception):
28    """Error class raised by :class:`ExecWrapper`"""
29
30
31class ExecWrapper(object):
32    """Base class that runs an executable in a subprocess."""
33    Error = ExecError
34
35    def __init__(self, manager=None, executable=None, verbose=0):
36        """
37        Args:
38            manager: :class:`TaskManager` object responsible for the submission of the jobs.
39                if manager is None, the default manager is used.
40            executable: path to the executable.
41            verbose: Verbosity level.
42        """
43        from .tasks import TaskManager
44        self.manager = manager if manager is not None else TaskManager.from_user_config()
45        self.manager = self.manager.to_shell_manager(mpi_procs=1)
46
47        self.executable = executable if executable is not None else self.name
48        assert os.path.basename(self.executable) == self.name
49        self.verbose = int(verbose)
50
51    def __str__(self):
52        return "%s" % self.executable
53
54    @property
55    def name(self):
56        return self._name
57
58    def execute(self, workdir, exec_args=None):
59        # Try to execute binary without and with mpirun.
60        try:
61            return self._execute(workdir, with_mpirun=True, exec_args=exec_args)
62        except self.Error:
63            return self._execute(workdir, with_mpirun=False, exec_args=exec_args)
64
65    def _execute(self, workdir, with_mpirun=False, exec_args=None):
66        """
67        Execute the executable in a subprocess inside workdir.
68
69        Some executables fail if we try to launch them with mpirun.
70        Use with_mpirun=False to run the binary without it.
71        """
72        qadapter = self.manager.qadapter
73        if not with_mpirun: qadapter.name = None
74        if self.verbose:
75            print("Working in:", workdir)
76
77        script = qadapter.get_script_str(
78            job_name=self.name,
79            launch_dir=workdir,
80            executable=self.executable,
81            qout_path="qout_file.path",
82            qerr_path="qerr_file.path",
83            stdin=self.stdin_fname,
84            stdout=self.stdout_fname,
85            stderr=self.stderr_fname,
86            exec_args=exec_args
87        )
88
89        # Write the script.
90        script_file = os.path.join(workdir, "run" + self.name + ".sh")
91        with open(script_file, "w") as fh:
92            fh.write(script)
93            os.chmod(script_file, 0o740)
94
95        qjob, process = qadapter.submit_to_queue(script_file)
96        self.stdout_data, self.stderr_data = process.communicate()
97        self.returncode = process.returncode
98        #raise self.Error("%s returned %s\n cmd_str: %s" % (self, self.returncode, self.cmd_str))
99
100        return self.returncode
101
102
103class Mrgscr(ExecWrapper):
104    _name = "mrgscr"
105
106    def merge_qpoints(self, workdir, files_to_merge, out_prefix):
107        """
108        Execute mrgscr inside directory `workdir` to merge `files_to_merge`.
109        Produce new file with prefix `out_prefix`
110        """
111        # We work with absolute paths.
112        files_to_merge = [os.path.abspath(s) for s in list_strings(files_to_merge)]
113        nfiles = len(files_to_merge)
114
115        if self.verbose:
116            print("Will merge %d files with output_prefix %s" % (nfiles, out_prefix))
117            for (i, f) in enumerate(files_to_merge):
118                print(" [%d] %s" % (i, f))
119
120        if nfiles == 1:
121            raise self.Error("merge_qpoints does not support nfiles == 1")
122
123        self.stdin_fname, self.stdout_fname, self.stderr_fname = \
124            map(os.path.join, 3 * [workdir], ["mrgscr.stdin", "mrgscr.stdout", "mrgscr.stderr"])
125
126        inp = StringIO()
127        inp.write(str(nfiles) + "\n")     # Number of files to merge.
128        inp.write(out_prefix + "\n")      # Prefix for the final output file:
129
130        for filename in files_to_merge:
131            inp.write(filename + "\n")   # List with the files to merge.
132
133        inp.write("1\n")                 # Option for merging q-points.
134
135        self.stdin_data = [s for s in inp.getvalue()]
136
137        with open(self.stdin_fname, "w") as fh:
138            fh.writelines(self.stdin_data)
139            # Force OS to write data to disk.
140            fh.flush()
141            os.fsync(fh.fileno())
142
143        self.execute(workdir)
144
145
146class Mrgddb(ExecWrapper):
147    _name = "mrgddb"
148
149    def merge(self, workdir, ddb_files, out_ddb, description, delete_source_ddbs=True):
150        """Merge DDB file, return the absolute path of the new database in workdir."""
151        # We work with absolute paths.
152        ddb_files = [os.path.abspath(s) for s in list_strings(ddb_files)]
153        if not os.path.isabs(out_ddb):
154            out_ddb = os.path.join(os.path.abspath(workdir), os.path.basename(out_ddb))
155
156        if self.verbose:
157            print("Will merge %d files into output DDB %s" % (len(ddb_files), out_ddb))
158            for i, f in enumerate(ddb_files):
159                print(" [%d] %s" % (i, f))
160
161        # Handle the case of a single file since mrgddb uses 1 to denote GS files!
162        if len(ddb_files) == 1:
163            with open(ddb_files[0], "r") as inh, open(out_ddb, "w") as out:
164                for line in inh:
165                    out.write(line)
166            return out_ddb
167
168        self.stdin_fname, self.stdout_fname, self.stderr_fname = \
169            map(os.path.join, 3 * [os.path.abspath(workdir)], ["mrgddb.stdin", "mrgddb.stdout", "mrgddb.stderr"])
170
171        inp = StringIO()
172        inp.write(out_ddb + "\n")              # Name of the output file.
173        inp.write(str(description) + "\n")     # Description.
174        inp.write(str(len(ddb_files)) + "\n")  # Number of input DDBs.
175
176        # Names of the DDB files.
177        for fname in ddb_files:
178            inp.write(fname + "\n")
179
180        self.stdin_data = [s for s in inp.getvalue()]
181
182        with open(self.stdin_fname, "wt") as fh:
183            fh.writelines(self.stdin_data)
184            # Force OS to write data to disk.
185            fh.flush()
186            os.fsync(fh.fileno())
187
188        retcode = self.execute(workdir, exec_args=['--nostrict'])
189        if retcode == 0 and delete_source_ddbs:
190            # Remove ddb files.
191            for f in ddb_files:
192                try:
193                    os.remove(f)
194                except IOError:
195                    pass
196
197        return out_ddb
198
199
200class Mrgdvdb(ExecWrapper):
201    _name = "mrgdv"
202
203    def merge(self, workdir, pot_files, out_dvdb, delete_source=True):
204        """
205        Merge POT files containing 1st order DFPT potential
206        return the absolute path of the new database in workdir.
207
208        Args:
209            delete_source: True if POT1 files should be removed after (successful) merge.
210        """
211        # We work with absolute paths.
212        pot_files = [os.path.abspath(s) for s in list_strings(pot_files)]
213        if not os.path.isabs(out_dvdb):
214            out_dvdb = os.path.join(os.path.abspath(workdir), os.path.basename(out_dvdb))
215
216        if self.verbose:
217            print("Will merge %d files into output DVDB %s" % (len(pot_files), out_dvdb))
218            for i, f in enumerate(pot_files):
219                print(" [%d] %s" % (i, f))
220
221        # Handle the case of a single file since mrgddb uses 1 to denote GS files!
222        if len(pot_files) == 1:
223            with open(pot_files[0], "r") as inh, open(out_dvdb, "w") as out:
224                for line in inh:
225                    out.write(line)
226            return out_dvdb
227
228        self.stdin_fname, self.stdout_fname, self.stderr_fname = \
229            map(os.path.join, 3 * [os.path.abspath(workdir)], ["mrgdvdb.stdin", "mrgdvdb.stdout", "mrgdvdb.stderr"])
230
231        inp = StringIO()
232        inp.write(out_dvdb + "\n")             # Name of the output file.
233        inp.write(str(len(pot_files)) + "\n")  # Number of input POT files.
234
235        # Names of the POT files.
236        for fname in pot_files:
237            inp.write(fname + "\n")
238
239        self.stdin_data = [s for s in inp.getvalue()]
240
241        with open(self.stdin_fname, "wt") as fh:
242            fh.writelines(self.stdin_data)
243            # Force OS to write data to disk.
244            fh.flush()
245            os.fsync(fh.fileno())
246
247        retcode = self.execute(workdir)
248        if retcode == 0 and delete_source:
249            # Remove pot files.
250            for f in pot_files:
251                try:
252                    os.remove(f)
253                except IOError:
254                    pass
255
256        return out_dvdb
257
258
259class Cut3D(ExecWrapper):
260    _name = "cut3d"
261
262    def cut3d(self, cut3d_input, workdir):
263        """
264        Runs cut3d with a Cut3DInput
265
266        Args:
267            cut3d_input: a Cut3DInput object.
268            workdir: directory where cut3d is executed.
269
270        Returns:
271            (string) absolute path to the standard output of the cut3d execution.
272            (string) absolute path to the output filepath. None if output is required.
273        """
274        self.stdin_fname, self.stdout_fname, self.stderr_fname = \
275            map(os.path.join, 3 * [os.path.abspath(workdir)], ["cut3d.stdin", "cut3d.stdout", "cut3d.stderr"])
276
277        cut3d_input.write(self.stdin_fname)
278
279        retcode = self._execute(workdir, with_mpirun=False)
280
281        if retcode != 0:
282            raise RuntimeError("Error while running cut3d in %s." % workdir)
283
284        output_filepath = cut3d_input.output_filepath
285
286        if output_filepath is not None:
287            if not os.path.isabs(output_filepath):
288                output_filepath = os.path.abspath(os.path.join(workdir, output_filepath))
289
290            if not os.path.isfile(output_filepath):
291                raise RuntimeError("The file was not converted correctly in %s." % workdir)
292
293        return self.stdout_fname, output_filepath
294
295
296class Fold2Bloch(ExecWrapper):
297    """Wrapper for fold2Bloch Fortran executable."""
298    _name = "fold2Bloch"
299
300    def unfold(self, wfkpath, folds, workdir=None):
301        import tempfile
302        workdir = tempfile.mkdtemp() if workdir is None else workdir
303
304        self.stdin_fname = None
305        self.stdout_fname, self.stderr_fname = \
306            map(os.path.join, 2 * [workdir], ["fold2bloch.stdout", "fold2bloch.stderr"])
307
308        folds = np.array(folds, dtype=np.int).flatten()
309        if len(folds) not in (3, 9):
310            raise ValueError("Expecting 3 ints or 3x3 matrix but got %s" % (str(folds)))
311        fold_arg = ":".join((str(f) for f in folds))
312        wfkpath = os.path.abspath(wfkpath)
313        if not os.path.isfile(wfkpath):
314            raise RuntimeError("WFK file `%s` does not exist in %s" % (wfkpath, workdir))
315
316        # Usage: $ fold2Bloch file_WFK x:y:z (folds)
317        retcode = self.execute(workdir, exec_args=[wfkpath, fold_arg])
318        if retcode:
319            print("stdout:")
320            print(self.stdout_data)
321            print("stderr:")
322            print(self.stderr_data)
323            raise RuntimeError("fold2bloch returned %s in %s" % (retcode, workdir))
324
325        filepaths = [f for f in os.listdir(workdir) if f.endswith("_FOLD2BLOCH.nc")]
326        if len(filepaths) != 1:
327            raise RuntimeError("Cannot find *_FOLD2BLOCH.nc file in: %s" % str(os.listdir(workdir)))
328
329        return os.path.join(workdir, filepaths[0])
330