1# coding: utf-8
3A Flow is a container for Works, and works consist of tasks.
4Flows are the final objects that can be dumped directly to a pickle file on disk
5Flows are executed using abirun (abipy).
7import os
8import sys
9import time
10import collections
11import warnings
12import shutil
13import tempfile
14import numpy as np
16from io import StringIO
17from pprint import pprint
18from tabulate import tabulate
19from pydispatch import dispatcher
20from collections import OrderedDict
21from monty.collections import dict2namedtuple
22from monty.string import list_strings, is_string, make_banner
23from monty.operator import operator_from_str
24from monty.io import FileLock
25from monty.pprint import draw_tree
26from monty.termcolor import cprint, colored, cprint_map, get_terminal_size
27from monty.inspect import find_top_pyfile
28from monty.json import MSONable
29from pymatgen.util.serialization import pmg_pickle_load, pmg_pickle_dump, pmg_serialize
30from pymatgen.core.units import Memory
31from pymatgen.util.io_utils import AtomicFile
32from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt
33from abipy.tools.printing import print_dataframe
34from abipy.flowtk import wrappers
35from .nodes import Status, Node, NodeError, NodeResults, Dependency, GarbageCollector, check_spectator
36from .tasks import ScfTask, TaskManager, FixQueueCriticalError
37from .utils import File, Directory, Editor
38from .works import NodeContainer, Work, BandStructureWork, PhononWork, BecWork, G0W0Work, QptdmWork, DteWork
39from .events import EventsParser
41__author__ = "Matteo Giantomassi"
42__copyright__ = "Copyright 2013, The Materials Project"
43__version__ = "0.1"
44__maintainer__ = "Matteo Giantomassi"
47__all__ = [
48    "Flow",
49    "G0W0WithQptdmFlow",
50    "bandstructure_flow",
51    "g0w0_flow",
def as_set(obj):
    """
    Return `obj` as a set. None and objects that are already sets are returned unchanged.

    Iterables are converted with set(), any other object is wrapped in a one-element set.

    >>> assert as_set(None) is None and as_set(1) == set([1]) and as_set(range(1,3)) == set([1, 2])
    """
    if obj is None:
        return None
    if isinstance(obj, collections.abc.Set):
        return obj

    is_iterable = isinstance(obj, collections.abc.Iterable)
    return set(obj) if is_iterable else {obj}
class FlowResults(NodeResults):
    """Results of a flow: the node results plus the files registered for GridFs."""

    JSON_SCHEMA = NodeResults.JSON_SCHEMA.copy()
    #JSON_SCHEMA["properties"] = {
    #    "queries": {"type": "string", "required": True},
    #}

    @classmethod
    def from_node(cls, flow):
        """Initialize an instance from a flow and register its output files and pickle file."""
        results = super().from_node(flow)

        # Will put all files found in outdir in GridFs, indexed by their basename.
        gridfs_files = {}
        for path in flow.outdir.list_filepaths():
            gridfs_files[os.path.basename(path)] = path

        # Add the pickle file. With pickle protocol 0 the path is passed in a (path, "t") tuple.
        if flow.pickle_protocol != 0:
            gridfs_files["pickle"] = flow.pickle_file
        else:
            gridfs_files["pickle"] = (flow.pickle_file, "t")

        results.add_gridfs_files(**gridfs_files)
        return results
class FlowError(NodeError):
    """Exception raised by :class:`Flow` methods (exposed as `Flow.Error`)."""
96class Flow(Node, NodeContainer, MSONable):
97    """
98    This object is a container of work. Its main task is managing the
99    possible inter-dependencies among the work and the creation of
100    dynamic workflows that are generated by callbacks registered by the user.
102    Attributes:
104        creation_date: String with the creation_date
105        pickle_protocol: Protocol for Pickle database (default: -1 i.e. latest protocol)
107    Important methods for constructing flows:
109        register_work: register (add) a work to the flow
        register_task: register a work that contains only this task returns the work
111        allocate: propagate the workdir and manager of the flow to all the registered tasks
112        build:
113        build_and_pickle_dump:
114    """
115    VERSION = "0.1"
116    PICKLE_FNAME = "__AbinitFlow__.pickle"
118    Error = FlowError
120    Results = FlowResults
122    @classmethod
123    def from_inputs(cls, workdir, inputs, manager=None, pickle_protocol=-1, task_class=ScfTask,
124                    work_class=Work, remove=False):
126        """
127        Construct a simple flow from a list of inputs. The flow contains a single Work with
128        tasks whose class is given by task_class.
130        .. warning::
132            Don't use this interface if you have dependencies among the tasks.
134        Args:
135            workdir: String specifying the directory where the works will be produced.
136            inputs: List of inputs.
137            manager: |TaskManager| object responsible for the submission of the jobs.
138                If manager is None, the object is initialized from the yaml file
139                located either in the working directory or in the user configuration dir.
140            pickle_protocol: Pickle protocol version used for saving the status of the object.
141                -1 denotes the latest version supported by the python interpreter.
142            task_class: The class of the |Task|.
143            work_class: The class of the |Work|.
144            remove: attempt to remove working directory `workdir` if directory already exists.
145        """
146        if not isinstance(inputs, (list, tuple)): inputs = [inputs]
148        flow = cls(workdir, manager=manager, pickle_protocol=pickle_protocol, remove=remove)
149        work = work_class()
150        for inp in inputs:
151            work.register(inp, task_class=task_class)
152        flow.register_work(work)
154        return flow.allocate()
156    @classmethod
157    def as_flow(cls, obj):
158        """Convert obj into a Flow. Accepts filepath, dict, or Flow object."""
159        if isinstance(obj, cls): return obj
160        if is_string(obj):
161            return cls.pickle_load(obj)
162        elif isinstance(obj, collections.abc.Mapping):
163            return cls.from_dict(obj)
164        else:
165            raise TypeError("Don't know how to convert type %s into a Flow" % type(obj))
167    def __init__(self, workdir, manager=None, pickle_protocol=-1, remove=False):
168        """
169        Args:
170            workdir: String specifying the directory where the works will be produced.
171                     if workdir is None, the initialization of the working directory
172                     is performed by flow.allocate(workdir).
173            manager: |TaskManager| object responsible for the submission of the jobs.
174                     If manager is None, the object is initialized from the yaml file
175                     located either in the working directory or in the user configuration dir.
176            pickle_protocol: Pickle protocol version used for saving the status of the object.
177                          -1 denotes the latest version supported by the python interpreter.
178            remove: attempt to remove working directory `workdir` if directory already exists.
179        """
180        super().__init__()
182        if workdir is not None:
183            if remove and os.path.exists(workdir): shutil.rmtree(workdir)
184            self.set_workdir(workdir)
186        self.creation_date = time.asctime()
188        if manager is None: manager = TaskManager.from_user_config()
189        self.manager = manager.deepcopy()
191        # List of works.
192        self._works = []
194        self._waited = 0
196        # List of callbacks that must be executed when the dependencies reach S_OK
197        self._callbacks = []
199        # Install default list of handlers at the flow level.
200        # Users can override the default list by calling flow.install_event_handlers in the script.
201        # Example:
202        #
203        #    # flow level (common case)
204        #    flow.install_event_handlers(handlers=my_handlers)
205        #
206        #    # task level (advanced mode)
207        #    flow[0][0].install_event_handlers(handlers=my_handlers)
208        #
209        self.install_event_handlers()
211        self.pickle_protocol = int(pickle_protocol)
213        # ID used to access mongodb
214        self._mongo_id = None
216        # Save the location of the script used to generate the flow.
217        # This trick won't work if we are running with nosetests, py.test etc
218        pyfile = find_top_pyfile()
219        if "python" in pyfile or "ipython" in pyfile: pyfile = "<" + pyfile + ">"
220        self.set_pyfile(pyfile)
222        # TODO
223        # Signal slots: a dictionary with the list
224        # of callbacks indexed by node_id and SIGNAL_TYPE.
225        # When the node changes its status, it broadcast a signal.
226        # The flow is listening to all the nodes of the calculation
227        # [node_id][SIGNAL] = list_of_signal_handlers
228        #self._sig_slots =  slots = {}
229        #for work in self:
230        #    slots[work] = {s: [] for s in work.S_ALL}
232        #for task in self.iflat_tasks():
233        #    slots[task] = {s: [] for s in work.S_ALL}
235        self.on_all_ok_num_calls = 0
237    @pmg_serialize
238    def as_dict(self, **kwargs):
239        """
240        JSON serialization, note that we only need to save
241        a string with the working directory since the object will be
242        reconstructed from the pickle file located in workdir
243        """
244        return {"workdir": self.workdir}
246    # This is needed for fireworks.
247    to_dict = as_dict
249    @classmethod
250    def from_dict(cls, d, **kwargs):
251        """Reconstruct the flow from the pickle file."""
252        return cls.pickle_load(d["workdir"], **kwargs)
254    @classmethod
255    def temporary_flow(cls, manager=None):
256        """Return a Flow in a temporary directory. Useful for unit tests."""
257        return cls(workdir=tempfile.mkdtemp(), manager=manager)
259    def set_workdir(self, workdir, chroot=False):
260        """
261        Set the working directory. Cannot be set more than once unless chroot is True
262        """
263        if not chroot and hasattr(self, "workdir") and self.workdir != workdir:
264            raise ValueError("self.workdir != workdir: %s, %s" % (self.workdir,  workdir))
266        # Directories with (input|output|temporary) data.
267        self.workdir = os.path.abspath(workdir)
268        self.indir = Directory(os.path.join(self.workdir, "indata"))
269        self.outdir = Directory(os.path.join(self.workdir, "outdata"))
270        self.tmpdir = Directory(os.path.join(self.workdir, "tmpdata"))
271        self.wdir = Directory(self.workdir)
273    def reload(self):
274        """
275        Reload the flow from the pickle file. Used when we are monitoring the flow
276        executed by the scheduler. In this case, indeed, the flow might have been changed
277        by the scheduler and we have to reload the new flow in memory.
278        """
279        new = self.__class__.pickle_load(self.workdir)
280        self = new
282    @classmethod
283    def pickle_load(cls, filepath, spectator_mode=True, remove_lock=False):
284        """
285        Loads the object from a pickle file and performs initial setup.
287        Args:
288            filepath: Filename or directory name. It filepath is a directory, we
289                scan the directory tree starting from filepath and we
290                read the first pickle database. Raise RuntimeError if multiple
291                databases are found.
292            spectator_mode: If True, the nodes of the flow are not connected by signals.
293                This option is usually used when we want to read a flow
294                in read-only mode and we want to avoid callbacks that can change the flow.
295            remove_lock:
296                True to remove the file lock if any (use it carefully).
297        """
298        if os.path.isdir(filepath):
299            # Walk through each directory inside path and find the pickle database.
300            for dirpath, dirnames, filenames in os.walk(filepath):
301                fnames = [f for f in filenames if f == cls.PICKLE_FNAME]
302                if fnames:
303                    if len(fnames) == 1:
304                        filepath = os.path.join(dirpath, fnames[0])
305                        break  # Exit os.walk
306                    else:
307                        err_msg = "Found multiple databases:\n %s" % str(fnames)
308                        raise RuntimeError(err_msg)
309            else:
310                err_msg = "Cannot find %s inside directory %s" % (cls.PICKLE_FNAME, filepath)
311                raise ValueError(err_msg)
313        if remove_lock and os.path.exists(filepath + ".lock"):
314            try:
315                os.remove(filepath + ".lock")
316            except Exception:
317                pass
319        with FileLock(filepath):
320            with open(filepath, "rb") as fh:
321                flow = pmg_pickle_load(fh)
323        # Check if versions match.
324        if flow.VERSION != cls.VERSION:
325            msg = ("File flow version %s != latest version %s\n."
326                   "Regenerate the flow to solve the problem " % (flow.VERSION, cls.VERSION))
327            warnings.warn(msg)
329        flow.set_spectator_mode(spectator_mode)
331        # Recompute the status of each task since tasks that
332        # have been submitted previously might be completed.
333        flow.check_status()
334        return flow
336    # Handy alias
337    from_file = pickle_load
339    @classmethod
340    def pickle_loads(cls, s):
341        """Reconstruct the flow from a string."""
342        strio = StringIO()
343        strio.write(s)
344        strio.seek(0)
345        flow = pmg_pickle_load(strio)
346        return flow
348    def get_panel(self):
349        """Build panel with widgets to interact with the |Flow| either in a notebook or in panel app."""
350        from abipy.panels.flows import FlowPanel
351        return FlowPanel(self).get_panel()
353    def __len__(self):
354        return len(self.works)
356    def __iter__(self):
357        return self.works.__iter__()
359    def __getitem__(self, slice):
360        return self.works[slice]
362    def set_pyfile(self, pyfile):
363        """
364        Set the path of the python script used to generate the flow.
366        .. Example:
368            flow.set_pyfile(__file__)
369        """
370        # TODO: Could use a frame hack to get the caller outside abinit
371        # so that pyfile is automatically set when we __init__ it!
372        self._pyfile = os.path.abspath(pyfile)
374    @property
375    def pyfile(self):
376        """
377        Absolute path of the python script used to generate the flow. Set by `set_pyfile`
378        """
379        try:
380            return self._pyfile
381        except AttributeError:
382            return None
384    @property
385    def pid_file(self):
386        """The path of the pid file created by PyFlowScheduler."""
387        return os.path.join(self.workdir, "_PyFlowScheduler.pid")
389    @property
390    def has_scheduler(self):
391        """True if there's a scheduler running the flow."""
392        return os.path.exists(self.pid_file)
394    def check_pid_file(self):
395        """
396        This function checks if we are already running the |Flow| with a :class:`PyFlowScheduler`.
397        Raises: Flow.Error if the pid file of the scheduler exists.
398        """
399        if not os.path.exists(self.pid_file):
400            return 0
402        self.show_status()
403        raise self.Error("""\n\
404            pid_file
405            %s
406            already exists. There are two possibilities:
408               1) There's an another instance of PyFlowScheduler running
409               2) The previous scheduler didn't exit in a clean way
411            To solve case 1:
412               Kill the previous scheduler (use 'kill pid' where pid is the number reported in the file)
413               Then you can restart the new scheduler.
415            To solve case 2:
416               Remove the pid_file and restart the scheduler.
418            Exiting""" % self.pid_file)
420    @property
421    def pickle_file(self):
422        """The path of the pickle file."""
423        return os.path.join(self.workdir, self.PICKLE_FNAME)
425    @property
426    def mongo_id(self):
427        return self._mongo_id
429    @mongo_id.setter
430    def mongo_id(self, value):
431        if self.mongo_id is not None:
432            raise RuntimeError("Cannot change mongo_id %s" % self.mongo_id)
433        self._mongo_id = value
435    #def mongodb_upload(self, **kwargs):
436    #    from abiflows.core.scheduler import FlowUploader
437    #    FlowUploader().upload(self, **kwargs)
439    def validate_json_schema(self):
440        """Validate the JSON schema. Return list of errors."""
441        errors = []
443        for work in self:
444            for task in work:
445                if not task.get_results().validate_json_schema():
446                    errors.append(task)
447            if not work.get_results().validate_json_schema():
448                errors.append(work)
449        if not self.get_results().validate_json_schema():
450            errors.append(self)
452        return errors
454    def get_mongo_info(self):
455        """
456        Return a JSON dictionary with information on the flow.
457        Mainly used for constructing the info section in `FlowEntry`.
458        The default implementation is empty. Subclasses must implement it
459        """
460        return {}
462    def mongo_assimilate(self):
463        """
464        This function is called by client code when the flow is completed
465        Return a JSON dictionary with the most important results produced
466        by the flow. The default implementation is empty. Subclasses must implement it
467        """
468        return {}
470    @property
471    def works(self):
472        """List of |Work| objects contained in self.."""
473        return self._works
475    @property
476    def all_ok(self):
477        """True if all the tasks in works have reached `S_OK`."""
478        all_ok = all(work.all_ok for work in self)
479        if all_ok:
480            all_ok = self.on_all_ok()
481        return all_ok
483    @property
484    def num_tasks(self):
485        """Total number of tasks"""
486        return len(list(self.iflat_tasks()))
488    @property
489    def errored_tasks(self):
490        """List of errored tasks."""
491        etasks = []
492        for status in [self.S_ERROR, self.S_QCRITICAL, self.S_ABICRITICAL]:
493            etasks.extend(list(self.iflat_tasks(status=status)))
495        return set(etasks)
497    @property
498    def num_errored_tasks(self):
499        """The number of tasks whose status is `S_ERROR`."""
500        return len(self.errored_tasks)
502    @property
503    def unconverged_tasks(self):
504        """List of unconverged tasks."""
505        return list(self.iflat_tasks(status=self.S_UNCONVERGED))
507    @property
508    def num_unconverged_tasks(self):
509        """The number of tasks whose status is `S_UNCONVERGED`."""
510        return len(self.unconverged_tasks)
512    @property
513    def status_counter(self):
514        """
515        Returns a :class:`Counter` object that counts the number of tasks with
516        given status (use the string representation of the status as key).
517        """
518        # Count the number of tasks with given status in each work.
519        counter = self[0].status_counter
520        for work in self[1:]:
521            counter += work.status_counter
523        return counter
525    @property
526    def ncores_reserved(self):
527        """
528        Returns the number of cores reserved in this moment.
529        A core is reserved if the task is not running but
530        we have submitted the task to the queue manager.
531        """
532        return sum(work.ncores_reserved for work in self)
534    @property
535    def ncores_allocated(self):
536        """
537        Returns the number of cores allocated in this moment.
538        A core is allocated if it's running a task or if we have
539        submitted a task to the queue manager but the job is still pending.
540        """
541        return sum(work.ncores_allocated for work in self)
543    @property
544    def ncores_used(self):
545        """
546        Returns the number of cores used in this moment.
547        A core is used if there's a job that is running on it.
548        """
549        return sum(work.ncores_used for work in self)
551    @property
552    def has_chrooted(self):
553        """
554        Returns a string that evaluates to True if we have changed
555        the workdir for visualization purposes e.g. we are using sshfs.
556        to mount the remote directory where the `Flow` is located.
557        The string gives the previous workdir of the flow.
558        """
559        try:
560            return self._chrooted_from
561        except AttributeError:
562            return ""
564    def chroot(self, new_workdir):
565        """
566        Change the workir of the |Flow|. Mainly used for
567        allowing the user to open the GUI on the local host
568        and access the flow from remote via sshfs.
570        .. note::
571            Calling this method will make the flow go in read-only mode.
572        """
573        self._chrooted_from = self.workdir
574        self.set_workdir(new_workdir, chroot=True)
576        for i, work in enumerate(self):
577            new_wdir = os.path.join(self.workdir, "w" + str(i))
578            work.chroot(new_wdir)
580    def groupby_status(self):
581        """
582        Returns a ordered dictionary mapping the task status to
583        the list of named tuples (task, work_index, task_index).
584        """
585        Entry = collections.namedtuple("Entry", "task wi ti")
586        d = collections.defaultdict(list)
588        for task, wi, ti in self.iflat_tasks_wti():
589            d[task.status].append(Entry(task, wi, ti))
591        # Sort keys according to their status.
592        return OrderedDict([(k, d[k]) for k in sorted(list(d.keys()))])
594    def groupby_task_class(self):
595        """
596        Returns a dictionary mapping the task class to the list of tasks in the flow
597        """
598        # Find all Task classes
599        class2tasks = OrderedDict()
600        for task in self.iflat_tasks():
601            cls = task.__class__
602            if cls not in class2tasks: class2tasks[cls] = []
603            class2tasks[cls].append(task)
605        return class2tasks
607    def iflat_nodes(self, status=None, op="==", nids=None):
608        """
609        Generators that produces a flat sequence of nodes.
610        if status is not None, only the tasks with the specified status are selected.
611        nids is an optional list of node identifiers used to filter the nodes.
612        """
613        nids = as_set(nids)
615        if status is None:
616            if not (nids and self.node_id not in nids):
617                yield self
619            for work in self:
620                if nids and work.node_id not in nids: continue
621                yield work
622                for task in work:
623                    if nids and task.node_id not in nids: continue
624                    yield task
625        else:
626            # Get the operator from the string.
627            op = operator_from_str(op)
629            # Accept Task.S_FLAG or string.
630            status = Status.as_status(status)
632            if not (nids and self.node_id not in nids):
633                if op(self.status, status): yield self
635            for wi, work in enumerate(self):
636                if nids and work.node_id not in nids: continue
637                if op(work.status, status): yield work
639                for ti, task in enumerate(work):
640                    if nids and task.node_id not in nids: continue
641                    if op(task.status, status): yield task
643    def node_from_nid(self, nid):
644        """Return the node in the `Flow` with the given `nid` identifier"""
645        for node in self.iflat_nodes():
646            if node.node_id == nid: return node
647        raise ValueError("Cannot find node with node id: %s" % nid)
649    def iflat_tasks_wti(self, status=None, op="==", nids=None):
650        """
651        Generator to iterate over all the tasks of the `Flow`.
652        Yields:
654            (task, work_index, task_index)
656        If status is not None, only the tasks whose status satisfies
657        the condition (task.status op status) are selected
658        status can be either one of the flags defined in the |Task| class
659        (e.g Task.S_OK) or a string e.g "S_OK"
660        nids is an optional list of node identifiers used to filter the tasks.
661        """
662        return self._iflat_tasks_wti(status=status, op=op, nids=nids, with_wti=True)
664    def iflat_tasks(self, status=None, op="==", nids=None):
665        """
666        Generator to iterate over all the tasks of the |Flow|.
668        If status is not None, only the tasks whose status satisfies
669        the condition (task.status op status) are selected
670        status can be either one of the flags defined in the |Task| class
671        (e.g Task.S_OK) or a string e.g "S_OK"
672        nids is an optional list of node identifiers used to filter the tasks.
673        """
674        return self._iflat_tasks_wti(status=status, op=op, nids=nids, with_wti=False)
676    def _iflat_tasks_wti(self, status=None, op="==", nids=None, with_wti=True):
677        """
678        Generators that produces a flat sequence of task.
679        if status is not None, only the tasks with the specified status are selected.
680        nids is an optional list of node identifiers used to filter the tasks.
682        Returns:
683            (task, work_index, task_index) if with_wti is True else task
684        """
685        nids = as_set(nids)
687        if status is None:
688            for wi, work in enumerate(self):
689                for ti, task in enumerate(work):
690                    if nids and task.node_id not in nids: continue
691                    if with_wti:
692                        yield task, wi, ti
693                    else:
694                        yield task
696        else:
697            # Get the operator from the string.
698            op = operator_from_str(op)
700            # Accept Task.S_FLAG or string.
701            status = Status.as_status(status)
703            for wi, work in enumerate(self):
704                for ti, task in enumerate(work):
705                    if nids and task.node_id not in nids: continue
706                    if op(task.status, status):
707                        if with_wti:
708                            yield task, wi, ti
709                        else:
710                            yield task
712    def abivalidate_inputs(self):
713        """
714        Run ABINIT in dry mode to validate all the inputs of the flow.
716        Return:
717            (isok, tuples)
719            isok is True if all inputs are ok.
720            tuples is List of `namedtuple` objects, one for each task in the flow.
721            Each namedtuple has the following attributes:
723                retcode: Return code. 0 if OK.
724                log_file:  log file of the Abinit run, use log_file.read() to access its content.
725                stderr_file: stderr file of the Abinit run. use stderr_file.read() to access its content.
727        Raises:
728            `RuntimeError` if executable is not in $PATH.
729        """
730        if not self.allocated:
731            self.allocate()
733        isok, tuples = True, []
734        for task in self.iflat_tasks():
735            t = task.input.abivalidate()
736            if t.retcode != 0: isok = False
737            tuples.append(t)
739        return isok, tuples
741    def check_dependencies(self):
742        """Test the dependencies of the nodes for possible deadlocks."""
743        deadlocks = []
745        for task in self.iflat_tasks():
746            for dep in task.deps:
747                if dep.node.depends_on(task):
748                    deadlocks.append((task, dep.node))
750        if deadlocks:
751            lines = ["Detect wrong list of dependecies that will lead to a deadlock:"]
752            lines.extend(["%s <--> %s" % nodes for nodes in deadlocks])
753            raise RuntimeError("\n".join(lines))
755    def find_deadlocks(self):
756        """
757        This function detects deadlocks
759        Return:
760            named tuple with the tasks grouped in: deadlocks, runnables, running
761        """
762        # Find jobs that can be submitted and and the jobs that are already in the queue.
763        runnables = []
764        for work in self:
765            runnables.extend(work.fetch_alltasks_to_run())
766        runnables.extend(list(self.iflat_tasks(status=self.S_SUB)))
768        # Running jobs.
769        running = list(self.iflat_tasks(status=self.S_RUN))
771        # Find deadlocks.
772        err_tasks = self.errored_tasks
773        deadlocked = []
774        if err_tasks:
775            for task in self.iflat_tasks():
776                if any(task.depends_on(err_task) for err_task in err_tasks):
777                    deadlocked.append(task)
779        return dict2namedtuple(deadlocked=deadlocked, runnables=runnables, running=running)
781    def check_status(self, **kwargs):
782        """
783        Check the status of the works in self.
785        Args:
786            show: True to show the status of the flow.
787            kwargs: keyword arguments passed to show_status
788        """
789        for work in self:
790            work.check_status()
792        if kwargs.pop("show", False):
793            self.show_status(**kwargs)
795    @property
796    def status(self):
797        """The status of the |Flow| i.e. the minimum of the status of its tasks and its works"""
798        return min(work.get_all_status(only_min=True) for work in self)
800    #def restart_unconverged_tasks(self, max_nlauch, excs):
801    #    nlaunch = 0
802    #    for task in self.unconverged_tasks:
803    #        try:
804    #            self.history.info("Flow will try restart task %s" % task)
805    #            fired = task.restart()
806    #            if fired:
807    #                nlaunch += 1
808    #                max_nlaunch -= 1
810    #                if max_nlaunch == 0:
811    #                    self.history.info("Restart: too many jobs in the queue, returning")
812    #                    self.pickle_dump()
813    #                    return nlaunch, max_nlaunch
815    #        except task.RestartError:
816    #            excs.append(straceback())
818    #    return nlaunch, max_nlaunch
820    def fix_abicritical(self):
821        """
822        This function tries to fix critical events originating from ABINIT.
823        Returns the number of tasks that have been fixed.
824        """
825        count = 0
826        for task in self.iflat_tasks(status=self.S_ABICRITICAL):
827            count += task.fix_abicritical()
829        return count
831    def fix_queue_critical(self):
832        """
833        This function tries to fix critical events originating from the queue submission system.
835        Returns the number of tasks that have been fixed.
836        """
837        count = 0
838        for task in self.iflat_tasks(status=self.S_QCRITICAL):
839            self.history.info("Will try to fix task %s" % str(task))
840            try:
841                print(task.fix_queue_critical())
842                count += 1
843            except FixQueueCriticalError:
844                self.history.info("Not able to fix task %s" % task)
846        return count
848    def show_info(self, **kwargs):
849        """
850        Print info on the flow i.e. total number of tasks, works, tasks grouped by class.
852        Example:
854            Task Class      Number
855            ------------  --------
856            ScfTask              1
857            NscfTask             1
858            ScrTask              2
859            SigmaTask            6
860        """
861        stream = kwargs.pop("stream", sys.stdout)
863        lines = [str(self)]
864        app = lines.append
866        app("Number of works: %d, total number of tasks: %s" % (len(self), self.num_tasks))
867        app("Number of tasks with a given class:\n")
869        # Build Table
870        data = [[cls.__name__, len(tasks)]
871                for cls, tasks in self.groupby_task_class().items()]
872        app(str(tabulate(data, headers=["Task Class", "Number"])))
874        stream.write("\n".join(lines))
876    def compare_abivars(self, varnames, nids=None, wslice=None, printout=False, with_colors=False):
877        """
878        Print the input of the tasks to the given stream.
880        Args:
881            varnames:
882                List of Abinit variables. If not None, only the variable in varnames
883                are selected and printed.
884            nids: List of node identifiers. By defaults all nodes are shown
885            wslice: Slice object used to select works.
886            printout: True to print dataframe.
887            with_colors: True if task status should be colored.
888        """
889        varnames = [s.strip() for s in list_strings(varnames)]
890        index, rows = [], []
891        for task in self.select_tasks(nids=nids, wslice=wslice):
892            index.append(task.pos_str)
893            dstruct = task.input.structure.as_dict(fmt="abivars")
895            od = OrderedDict()
896            for vname in varnames:
897                value = task.input.get(vname, None)
898                if value is None: # maybe in structure?
899                    value = dstruct.get(vname, None)
900                od[vname] = value
902            od["task_class"] = task.__class__.__name__
903            od["status"] = task.status.colored if with_colors else str(task.status)
904            rows.append(od)
906        import pandas as pd
907        df = pd.DataFrame(rows, index=index)
908        if printout:
909            print_dataframe(df, title="Input variables:")
911        return df
913    def get_dims_dataframe(self, nids=None, printout=False, with_colors=False):
914        """
915        Analyze output files produced by the tasks. Print pandas DataFrame with dimensions.
917        Args:
918            nids: List of node identifiers. By defaults all nodes are shown
919            printout: True to print dataframe.
920            with_colors: True if task status should be colored.
921        """
922        abo_paths, index, status, abo_relpaths, task_classes, task_nids = [], [], [], [], [], []
924        for task in self.iflat_tasks(nids=nids):
925            if task.status not in (self.S_OK, self.S_RUN): continue
926            if not task.is_abinit_task: continue
928            abo_paths.append(task.output_file.path)
929            index.append(task.pos_str)
930            status.append(task.status.colored if with_colors else str(task.status))
931            abo_relpaths.append(os.path.relpath(task.output_file.relpath))
932            task_classes.append(task.__class__.__name__)
933            task_nids.append(task.node_id)
935        if not abo_paths: return
937        # Get dimensions from output files as well as walltime/cputime
938        from abipy.abio.outputs import AboRobot
939        robot = AboRobot.from_files(abo_paths)
940        df = robot.get_dims_dataframe(with_time=True, index=index)
942        # Add columns to the dataframe.
943        status = [str(s) for s in status]
944        df["task_class"] = task_classes
945        df["relpath"] = abo_relpaths
946        df["node_id"] = task_nids
947        df["status"] = status
949        if printout:
950            print_dataframe(df, title="Table with Abinit dimensions:\n")
952        return df
954    def compare_structures(self, nids=None, with_spglib=False, what="io", verbose=0,
955                           precision=3, printout=False, with_colors=False):
956        """
957        Analyze structures of the tasks (input and output structures if it's a relaxation task.
958        Print pandas DataFrame
960        Args:
961            nids: List of node identifiers. By defaults all nodes are shown
962            with_spglib: If True, spglib is invoked to get the spacegroup symbol and number
963            what (str): "i" for input structures, "o" for output structures.
964            precision: Floating point output precision (number of significant digits).
965                This is only a suggestion
966            printout: True to print dataframe.
967            with_colors: True if task status should be colored.
968        """
969        from abipy.core.structure import dataframes_from_structures
970        structures, index, status, max_forces, pressures, task_classes = [], [], [], [], [], []
972        def push_data(post, task, structure, cart_forces, pressure):
973            """Helper function to fill lists"""
974            index.append(task.pos_str + post)
975            structures.append(structure)
976            status.append(task.status.colored if with_colors else str(task.status))
977            if cart_forces is not None:
978                fmods = np.sqrt([np.dot(f, f) for f in cart_forces])
979                max_forces.append(fmods.max())
980            else:
981                max_forces.append(None)
982            pressures.append(pressure)
983            task_classes.append(task.__class__.__name__)
985        for task in self.iflat_tasks(nids=nids):
986            if "i" in what:
987                push_data("_in", task, task.input.structure, cart_forces=None, pressure=None)
989            if "o" not in what:
990                continue
992            # Add final structure, pressure and max force if relaxation task or GS task
993            if task.status in (task.S_RUN, task.S_OK):
994                if hasattr(task, "open_hist"):
995                    # Structural relaxations produce HIST.nc and we can get
996                    # the final structure or the structure of the last relaxation step.
997                    try:
998                        with task.open_hist() as hist:
999                            final_structure = hist.final_structure
1000                            stress_cart_tensors, pressures_hist = hist.reader.read_cart_stress_tensors()
1001                            forces = hist.reader.read_cart_forces(unit="eV ang^-1")[-1]
1002                            push_data("_out", task, final_structure, forces, pressures_hist[-1])
1003                    except Exception as exc:
1004                        cprint("Exception while opening HIST.nc file of task: %s\n%s" % (task, str(exc)), "red")
1006                elif hasattr(task, "open_gsr") and task.status == task.S_OK and task.input.get("iscf", 7) >= 0:
1007                    with task.open_gsr() as gsr:
1008                        forces = gsr.reader.read_cart_forces(unit="eV ang^-1")
1009                        push_data("_out", task, gsr.structure, forces, gsr.pressure)
1011        dfs = dataframes_from_structures(structures, index=index, with_spglib=with_spglib, cart_coords=False)
1013        if any(f is not None for f in max_forces):
1014            # Add pressure and forces to the dataframe
1015            dfs.lattice["P [GPa]"] = pressures
1016            dfs.lattice["Max|F| eV/ang"] = max_forces
1018        # Add columns to the dataframe.
1019        status = [str(s) for s in status]
1020        dfs.lattice["task_class"] = task_classes
1021        dfs.lattice["status"] = dfs.coords["status"] = status
1023        if printout:
1024            print_dataframe(dfs.lattice, title="Lattice parameters:", precision=precision)
1025            if verbose:
1026                print_dataframe(dfs.coords, title="Atomic positions (columns give the site index):")
1027            else:
1028                print("Use `--verbose` to print atoms.")
1030        return dfs
1032    def compare_ebands(self, nids=None, with_path=True, with_ibz=True, with_spglib=False, verbose=0,
1033                       precision=3, printout=False, with_colors=False):
1034        """
1035        Analyze electron bands produced by the tasks.
1036        Return pandas DataFrame and |ElectronBandsPlotter|.
1038        Args:
1039            nids: List of node identifiers. By default, all nodes are shown
1040            with_path: Select files with ebands along k-path.
1041            with_ibz: Select files with ebands in the IBZ.
1042            with_spglib: If True, spglib is invoked to get the spacegroup symbol and number
1043            precision: Floating point output precision (number of significant digits).
1044                This is only a suggestion
1045            printout: True to print dataframe.
1046            with_colors: True if task status should be colored.
1048        Return: (df, ebands_plotter)
1049        """
1050        ebands_list, index, status, ncfiles, task_classes, task_nids = [], [], [], [], [], []
1052        # Cannot use robots because ElectronBands can be found in different filetypes
1053        for task in self.iflat_tasks(nids=nids, status=self.S_OK):
1054            # Read ebands from GSR or SIGRES files.
1055            for ext in ("gsr", "sigres"):
1056                task_open_ncfile = getattr(task, "open_%s" % ext, None)
1057                if task_open_ncfile is not None: break
1058            else:
1059                continue
1061            try:
1062                with task_open_ncfile() as ncfile:
1063                    if not with_path and ncfile.ebands.kpoints.is_path: continue
1064                    if not with_ibz and ncfile.ebands.kpoints.is_ibz: continue
1065                    ebands_list.append(ncfile.ebands)
1066                    index.append(task.pos_str)
1067                    status.append(task.status.colored if with_colors else str(task.status))
1068                    ncfiles.append(os.path.relpath(ncfile.filepath))
1069                    task_classes.append(task.__class__.__name__)
1070                    task_nids.append(task.node_id)
1071            except Exception as exc:
1072                cprint("Exception while opening nc file of task: %s\n%s" % (task, str(exc)), "red")
1074        if not ebands_list: return (None, None)
1076        from abipy.electrons.ebands import dataframe_from_ebands
1077        df = dataframe_from_ebands(ebands_list, index=index, with_spglib=with_spglib)
1078        ncfiles = [os.path.relpath(p, self.workdir) for p in ncfiles]
1080        # Add columns to the dataframe.
1081        status = [str(s) for s in status]
1082        df["task_class"] = task_classes
1083        df["ncfile"] = ncfiles
1084        df["node_id"] = task_nids
1085        df["status"] = status
1087        if printout:
1088            from abipy.tools.printing import print_dataframe
1089            print_dataframe(df, title="KS electronic bands:", precision=precision)
1091        from abipy.electrons.ebands import ElectronBandsPlotter
1092        ebands_plotter = ElectronBandsPlotter(key_ebands=zip(ncfiles, ebands_list))
1094        return df, ebands_plotter
1096    def compare_hist(self, nids=None, with_spglib=False, verbose=0,
1097                     precision=3, printout=False, with_colors=False):
1098        """
1099        Analyze HIST nc files produced by the tasks. Print pandas DataFrame with final results.
1100        Return: (df, hist_plotter)
1102        Args:
1103            nids: List of node identifiers. By defaults all nodes are shown
1104            with_spglib: If True, spglib is invoked to get the spacegroup symbol and number
1105            precision: Floating point output precision (number of significant digits).
1106                This is only a suggestion
1107            printout: True to print dataframe.
1108            with_colors: True if task status should be colored.
1109        """
1110        hist_paths, index, status, ncfiles, task_classes, task_nids = [], [], [], [], [], []
1112        for task in self.iflat_tasks(nids=nids):
1113            if task.status not in (self.S_OK, self.S_RUN): continue
1114            hist_path = task.outdir.has_abiext("HIST")
1115            if not hist_path: continue
1117            hist_paths.append(hist_path)
1118            index.append(task.pos_str)
1119            status.append(task.status.colored if with_colors else str(task.status))
1120            ncfiles.append(os.path.relpath(hist_path))
1121            task_classes.append(task.__class__.__name__)
1122            task_nids.append(task.node_id)
1124        if not hist_paths: return (None, None)
1125        from abipy.dynamics.hist import HistRobot
1126        robot = HistRobot.from_files(hist_paths, labels=hist_paths)
1127        df = robot.get_dataframe(index=index, with_spglib=with_spglib)
1128        ncfiles = [os.path.relpath(p, self.workdir) for p in ncfiles]
1130        # Add columns to the dataframe.
1131        status = [str(s) for s in status]
1132        df["task_class"] = task_classes
1133        df["ncfile"] = ncfiles
1134        df["node_id"] = task_nids
1135        df["status"] = status
1137        if printout:
1138            title = "Table with final structures, pressures in GPa and force stats in eV/Ang:\n"
1139            from abipy.tools.printing import print_dataframe
1140            print_dataframe(df, title=title, precision=precision)
1142        return df, robot
1144    def show_summary(self, **kwargs):
1145        """
1146        Print a short summary with the status of the flow and a counter task_status --> number_of_tasks
1148        Args:
1149            stream: File-like object, Default: sys.stdout
1151        Example:
1153            Status       Count
1154            ---------  -------
1155            Completed       10
1157            <Flow, node_id=27163, workdir=flow_gwconv_ecuteps>, num_tasks=10, all_ok=True
1158        """
1159        stream = kwargs.pop("stream", sys.stdout)
1160        stream.write("\n")
1161        table = list(self.status_counter.items())
1162        s = tabulate(table, headers=["Status", "Count"])
1163        stream.write(s + "\n")
1164        stream.write("\n")
1165        stream.write("%s, num_tasks=%s, all_ok=%s\n" % (str(self), self.num_tasks, self.all_ok))
1166        stream.write("\n")
1168    def show_status(self, **kwargs):
1169        """
1170        Report the status of the works and the status of the different tasks on the specified stream.
1172        Args:
1173            stream: File-like object, Default: sys.stdout
1174            nids:  List of node identifiers. By defaults all nodes are shown
1175            wslice: Slice object used to select works.
1176            verbose: Verbosity level (default 0). > 0 to show only the works that are not finalized.
1177        """
1178        stream = kwargs.pop("stream", sys.stdout)
1179        nids = as_set(kwargs.pop("nids", None))
1180        wslice = kwargs.pop("wslice", None)
1181        verbose = kwargs.pop("verbose", 0)
1182        wlist = None
1183        if wslice is not None:
1184            # Convert range to list of work indices.
1185            wlist = list(range(wslice.start, wslice.step, wslice.stop))
1187        #has_colours = stream_has_colours(stream)
1188        has_colours = True
1189        red = "red" if has_colours else None
1191        for i, work in enumerate(self):
1192            if nids and work.node_id not in nids: continue
1193            print("", file=stream)
1194            cprint_map("Work #%d: %s, Finalized=%s" % (i, work, work.finalized), cmap={"True": "green"}, file=stream)
1195            if wlist is not None and i in wlist: continue
1196            if verbose == 0 and work.finalized:
1197                print("  Finalized works are not shown. Use verbose > 0 to force output.", file=stream)
1198                continue
1200            headers = ["Task", "Status", "Queue", "MPI|Omp|Gb",
1201                       "Warn|Com", "Class", "Sub|Rest|Corr", "Time",
1202                       "Node_ID"]
1203            table = []
1204            tot_num_errors = 0
1205            for task in work:
1206                if nids and task.node_id not in nids: continue
1207                task_name = os.path.basename(task.name)
1209                # FIXME: This should not be done here.
1210                # get_event_report should be called only in check_status
1211                # Parse the events in the main output.
1212                report = task.get_event_report()
1214                # Get time info (run-time or time in queue or None)
1215                stime = None
1216                timedelta = task.datetimes.get_runtime()
1217                if timedelta is not None:
1218                    stime = str(timedelta) + "R"
1219                else:
1220                    timedelta = task.datetimes.get_time_inqueue()
1221                    if timedelta is not None:
1222                        stime = str(timedelta) + "Q"
1224                events = "|".join(2*["NA"])
1225                if report is not None:
1226                    events = '{:>4}|{:>3}'.format(*map(str, (
1227                       report.num_warnings, report.num_comments)))
1229                para_info = '{:>4}|{:>3}|{:>3}'.format(*map(str, (
1230                   task.mpi_procs, task.omp_threads, "%.1f" % task.mem_per_proc.to("Gb"))))
1232                task_info = list(map(str, [task.__class__.__name__,
1233                                 (task.num_launches, task.num_restarts, task.num_corrections), stime, task.node_id]))
1235                qinfo = "None"
1236                if task.queue_id is not None:
1237                    qname = str(task.qname)
1238                    if not verbose:
1239                        qname = qname[:min(5, len(qname))]
1240                    qinfo = str(task.queue_id) + "@" + qname
1242                if task.status.is_critical:
1243                    tot_num_errors += 1
1244                    task_name = colored(task_name, red)
1246                if has_colours:
1247                    table.append([task_name, task.status.colored, qinfo,
1248                                  para_info, events] + task_info)
1249                else:
1250                    table.append([task_name, str(task.status), qinfo, events,
1251                                  para_info] + task_info)
1253            # Print table and write colorized line with the total number of errors.
1254            print(tabulate(table, headers=headers, tablefmt="grid"), file=stream)
1255            if tot_num_errors:
1256                cprint("Total number of errors: %d" % tot_num_errors, "red", file=stream)
1257            print("", file=stream)
1259        if self.all_ok:
1260            cprint("\nall_ok reached\n", "green", file=stream)
1262    def show_events(self, status=None, nids=None, stream=sys.stdout):
1263        """
1264        Print the Abinit events (ERRORS, WARNIING, COMMENTS) to stdout
1266        Args:
1267            status: if not None, only the tasks with this status are select
1268            nids: optional list of node identifiers used to filter the tasks.
1269            stream: File-like object, Default: sys.stdout
1270        """
1271        nrows, ncols = get_terminal_size()
1273        for task in self.iflat_tasks(status=status, nids=nids):
1274            report = task.get_event_report()
1275            if report:
1276                print(make_banner(str(task), width=ncols, mark="="), file=stream)
1277                print(report, file=stream)
1278                #report = report.filter_types()
1280    def show_corrections(self, status=None, nids=None, stream=sys.stdout):
1281        """
1282        Show the corrections applied to the flow at run-time.
1284        Args:
1285            status: if not None, only the tasks with this status are select.
1286            nids: optional list of node identifiers used to filter the tasks.
1287            stream: File-like object, Default: sys.stdout
1289        Return: The number of corrections found.
1290        """
1291        nrows, ncols = get_terminal_size()
1292        count = 0
1293        for task in self.iflat_tasks(status=status, nids=nids):
1294            if task.num_corrections == 0: continue
1295            count += 1
1296            print(make_banner(str(task), width=ncols, mark="="), file=stream)
1297            for corr in task.corrections:
1298                pprint(corr, stream=stream)
1300        if not count: print("No correction found.", file=stream)
1301        return count
1303    def show_history(self, status=None, nids=None, full_history=False, metadata=False, stream=sys.stdout):
1304        """
1305        Print the history of the flow to stream
1307        Args:
1308            status: if not None, only the tasks with this status are select
1309            full_history: Print full info set, including nodes with an empty history.
1310            nids: optional list of node identifiers used to filter the tasks.
1311            metadata: print history metadata (experimental)
1312            stream: File-like object, Default: sys.stdout
1313        """
1314        nrows, ncols = get_terminal_size()
1316        works_done = []
1317        # Loop on the tasks and show the history of the work is not in works_done
1318        for task in self.iflat_tasks(status=status, nids=nids):
1319            work = task.work
1321            if work not in works_done:
1322                works_done.append(work)
1323                if work.history or full_history:
1324                    cprint(make_banner(str(work), width=ncols, mark="="), file=stream, **work.status.color_opts)
1325                    print(work.history.to_string(metadata=metadata), file=stream)
1327            if task.history or full_history:
1328                cprint(make_banner(str(task), width=ncols, mark="="), file=stream, **task.status.color_opts)
1329                print(task.history.to_string(metadata=metadata), file=stream)
1331        # Print the history of the flow.
1332        if self.history or full_history:
1333            cprint(make_banner(str(self), width=ncols, mark="="), file=stream, **self.status.color_opts)
1334            print(self.history.to_string(metadata=metadata), file=stream)
1336    def show_inputs(self, varnames=None, nids=None, wslice=None, stream=sys.stdout):
1337        """
1338        Print the input of the tasks to the given stream.
1340        Args:
1341            varnames: List of Abinit variables. If not None, only the variable in varnames
1342                are selected and printed.
1343            nids: List of node identifiers. By defaults all nodes are shown
1344            wslice: Slice object used to select works.
1345            stream: File-like object, Default: sys.stdout
1346        """
1347        if varnames is not None:
1348            # Build dictionary varname --> [(task1, value1), (task2, value2), ...]
1349            varnames = [s.strip() for s in list_strings(varnames)]
1350            dlist = collections.defaultdict(list)
1351            for task in self.select_tasks(nids=nids, wslice=wslice):
1352                dstruct = task.input.structure.as_dict(fmt="abivars")
1354                for vname in varnames:
1355                    value = task.input.get(vname, None)
1356                    if value is None: # maybe in structure?
1357                        value = dstruct.get(vname, None)
1358                    if value is not None:
1359                        dlist[vname].append((task, value))
1361            for vname in varnames:
1362                tv_list = dlist[vname]
1363                if not tv_list:
1364                    stream.write("[%s]: Found 0 tasks with this variable\n" % vname)
1365                else:
1366                    stream.write("[%s]: Found %s tasks with this variable\n" % (vname, len(tv_list)))
1367                    for i, (task, value) in enumerate(tv_list):
1368                        stream.write("   %s --> %s\n" % (str(value), task))
1369                stream.write("\n")
1371        else:
1372            lines = []
1373            for task in self.select_tasks(nids=nids, wslice=wslice):
1374                s = task.make_input(with_header=True)
1376                # Add info on dependencies.
1377                if task.deps:
1378                    s += "\n\nDependencies:\n" + "\n".join(str(dep) for dep in task.deps)
1379                else:
1380                    s += "\n\nDependencies: None"
1382                lines.append(2*"\n" + 80 * "=" + "\n" + s + 2*"\n")
1384            stream.writelines(lines)
1386    def listext(self, ext, stream=sys.stdout):
1387        """
1388        Print to the given `stream` a table with the list of the output files
1389        with the given `ext` produced by the flow.
1390        """
1391        nodes_files = []
1392        for node in self.iflat_nodes():
1393            filepath = node.outdir.has_abiext(ext)
1394            if filepath:
1395                nodes_files.append((node, File(filepath)))
1397        if nodes_files:
1398            print("Found %s files with extension `%s` produced by the flow" % (len(nodes_files), ext), file=stream)
1400            table = [[f.relpath, "%.2f" % (f.get_stat().st_size / 1024**2),
1401                      node.node_id, node.__class__.__name__]
1402                     for node, f in nodes_files]
1403            print(tabulate(table, headers=["File", "Size [Mb]", "Node_ID", "Node Class"]), file=stream)
1405        else:
1406            print("No output file with extension %s has been produced by the flow" % ext, file=stream)
1408    def select_tasks(self, nids=None, wslice=None, task_class=None):
1409        """
1410        Return a list with a subset of tasks.
1412        Args:
1413            nids: List of node identifiers.
1414            wslice: Slice object used to select works.
1415            task_class: String or class used to select tasks. Ignored if None.
1417        .. note::
1419            nids and wslice are mutually exclusive.
1420            If no argument is provided, the full list of tasks is returned.
1421        """
1422        if nids is not None:
1423            assert wslice is None
1424            tasks = self.tasks_from_nids(nids)
1426        elif wslice is not None:
1427            tasks = []
1428            for work in self[wslice]:
1429                tasks.extend([t for t in work])
1430        else:
1431            # All tasks selected if no option is provided.
1432            tasks = list(self.iflat_tasks())
1434        # Filter by task class
1435        if task_class is not None:
1436            tasks = [t for t in tasks if t.isinstance(task_class)]
1438        return tasks
1440    def get_task_scfcycles(self, nids=None, wslice=None, task_class=None, exclude_ok_tasks=False):
1441        """
1442        Return list of (taks, scfcycle) tuples for all the tasks in the flow with a SCF algorithm
1443        e.g. electronic GS-SCF iteration, DFPT-SCF iterations etc.
1445        Args:
1446            nids: List of node identifiers.
1447            wslice: Slice object used to select works.
1448            task_class: String or class used to select tasks. Ignored if None.
1449            exclude_ok_tasks: True if only running tasks should be considered.
1451        Returns:
1452            List of `ScfCycle` subclass instances.
1453        """
1454        select_status = [self.S_RUN] if exclude_ok_tasks else [self.S_RUN, self.S_OK]
1455        tasks_cycles = []
1457        for task in self.select_tasks(nids=nids, wslice=wslice):
1458            # Fileter
1459            if task.status not in select_status or task.cycle_class is None:
1460                continue
1461            if task_class is not None and not task.isinstance(task_class):
1462                continue
1463            try:
1464                cycle = task.cycle_class.from_file(task.output_file.path)
1465                if cycle is not None:
1466                    tasks_cycles.append((task, cycle))
1467            except Exception:
1468                # This is intentionally ignored because from_file can fail for several reasons.
1469                pass
1471        return tasks_cycles
1473    def show_tricky_tasks(self, verbose=0, stream=sys.stdout):
1474        """
1475        Print list of tricky tasks i.e. tasks that have been restarted or
1476        launched more than once or tasks with corrections.
1478        Args:
1479            verbose: Verbosity level. If > 0, task history and corrections (if any) are printed.
1480            stream: File-like object. Default: sys.stdout
1481        """
1482        nids, tasks = [], []
1483        for task in self.iflat_tasks():
1484            if task.num_launches > 1 or any(n > 0 for n in (task.num_restarts, task.num_corrections)):
1485                nids.append(task.node_id)
1486                tasks.append(task)
1488        if not nids:
1489            cprint("Everything's fine, no tricky tasks found", color="green", file=stream)
1490        else:
1491            self.show_status(nids=nids, stream=stream)
1492            if not verbose:
1493                print("Use --verbose to print task history.", file=stream)
1494                return
1496            for nid, task in zip(nids, tasks):
1497                cprint(repr(task), **task.status.color_opts, stream=stream)
1498                self.show_history(nids=[nid], full_history=False, metadata=False, stream=stream)
1499                #if task.num_restarts:
1500                #    self.show_restarts(nids=[nid])
1501                if task.num_corrections:
1502                    self.show_corrections(nids=[nid], stream=stream)
1504    def inspect(self, nids=None, wslice=None, **kwargs):
1505        """
1506        Inspect the tasks (SCF iterations, Structural relaxation ...) and
1507        produces matplotlib plots.
1509        Args:
1510            nids: List of node identifiers.
1511            wslice: Slice object used to select works.
1512            kwargs: keyword arguments passed to `task.inspect` method.
1514        .. note::
1516            nids and wslice are mutually exclusive.
1517            If nids and wslice are both None, all tasks in self are inspected.
1519        Returns:
1520            List of `matplotlib` figures.
1521        """
1522        figs = []
1523        for task in self.select_tasks(nids=nids, wslice=wslice):
1524            if hasattr(task, "inspect"):
1525                fig = task.inspect(**kwargs)
1526                if fig is None:
1527                    cprint("Cannot inspect Task %s" % task, color="blue")
1528                else:
1529                    figs.append(fig)
1530            else:
1531                cprint("Task %s does not provide an inspect method" % task, color="blue")
1533        return figs
1535    def get_results(self, **kwargs):
1536        results = self.Results.from_node(self)
1537        results.update(self.get_dict_for_mongodb_queries())
1538        return results
1540    def get_dict_for_mongodb_queries(self):
1541        """
1542        This function returns a dictionary with the attributes that will be
1543        put in the mongodb document to facilitate the query.
1544        Subclasses may want to replace or extend the default behaviour.
1545        """
1546        d = {}
1547        return d
1548        # TODO
1549        all_structures = [task.input.structure for task in self.iflat_tasks()]
1550        all_pseudos = [task.input.pseudos for task in self.iflat_tasks()]
1552    def look_before_you_leap(self):
1553        """
1554        This method should be called before running the calculation to make
1555        sure that the most important requirements are satisfied.
1557        Return:
1558            List of strings with inconsistencies/errors.
1559        """
1560        errors = []
1562        try:
1563            self.check_dependencies()
1564        except self.Error as exc:
1565            errors.append(str(exc))
1567        if self.has_db:
1568            try:
1569                self.manager.db_connector.get_collection()
1570            except Exception as exc:
1571                errors.append("""
1572                    ERROR while trying to connect to the MongoDB database:
1573                        Exception:
1574                            %s
1575                        Connector:
1576                            %s
1577                    """ % (exc, self.manager.db_connector))
1579        return "\n".join(errors)
1581    @property
1582    def has_db(self):
1583        """True if flow uses `MongoDB` to store the results."""
1584        return self.manager.has_db
1586    def db_insert(self):
1587        """
1588        Insert results in the `MongDB` database.
1589        """
1590        assert self.has_db
1591        # Connect to MongoDb and get the collection.
1592        coll = self.manager.db_connector.get_collection()
1593        print("Mongodb collection %s with count %d", coll, coll.count())
1595        start = time.time()
1596        for work in self:
1597            for task in work:
1598                results = task.get_results()
1599                pprint(results)
1600                results.update_collection(coll)
1601            results = work.get_results()
1602            pprint(results)
1603            results.update_collection(coll)
1604        print("MongoDb update done in %s [s]" % time.time() - start)
1606        results = self.get_results()
1607        pprint(results)
1608        results.update_collection(coll)
1610        # Update the pickle file to save the mongo ids.
1611        self.pickle_dump()
1613        for d in coll.find():
1614            pprint(d)
1616    def tasks_from_nids(self, nids):
1617        """
1618        Return the list of tasks associated to the given list of node identifiers (nids).
1620        .. note::
1622            Invalid ids are ignored
1623        """
1624        if not isinstance(nids, collections.abc.Iterable): nids = [nids]
1626        n2task = {task.node_id: task for task in self.iflat_tasks()}
1627        return [n2task[n] for n in nids if n in n2task]
1629    def wti_from_nids(self, nids):
1630        """Return the list of (w, t) indices from the list of node identifiers nids."""
1631        return [task.pos for task in self.tasks_from_nids(nids)]
1633    def open_files(self, what="o", status=None, op="==", nids=None, editor=None):
1634        """
1635        Open the files of the flow inside an editor (command line interface).
1637        Args:
1638            what: string with the list of characters selecting the file type
1639                  Possible choices:
1641                    i ==> input_file,
1642                    o ==> output_file,
1643                    f ==> files_file,
1644                    j ==> job_file,
1645                    l ==> log_file,
1646                    e ==> stderr_file,
1647                    q ==> qout_file,
1648                    all ==> all files.
1650            status: if not None, only the tasks with this status are select
1651            op: status operator. Requires status. A task is selected
1652                if task.status op status evaluates to true.
1653            nids: optional list of node identifiers used to filter the tasks.
1654            editor: Select the editor. None to use the default editor ($EDITOR shell env var)
1655        """
1656        # Build list of files to analyze.
1657        files = []
1658        for task in self.iflat_tasks(status=status, op=op, nids=nids):
1659            lst = task.select_files(what)
1660            if lst:
1661                files.extend(lst)
1663        return Editor(editor=editor).edit_files(files)
1665    def parse_timing(self, nids=None):
1666        """
1667        Parse the timer data in the main output file(s) of Abinit.
1668        Requires timopt /= 0 in the input file (usually timopt = -1)
1670        Args:
1671            nids: optional list of node identifiers used to filter the tasks.
1673        Return: :class:`AbinitTimerParser` instance, None if error.
1674        """
1675        # Get the list of output files according to nids.
1676        paths = [task.output_file.path for task in self.iflat_tasks(nids=nids)]
1678        # Parse data.
1679        from abipy.flowtk.abitimer import AbinitTimerParser
1680        parser = AbinitTimerParser()
1681        read_ok = parser.parse(paths)
1682        if read_ok:
1683            return parser
1684        return None
1686    def show_abierrors(self, nids=None, stream=sys.stdout):
1687        """
1688        Write to the given stream the list of ABINIT errors for all tasks whose status is S_ABICRITICAL.
1690        Args:
1691            nids: optional list of node identifiers used to filter the tasks.
1692            stream: File-like object. Default: sys.stdout
1693        """
1694        lines = []
1695        app = lines.append
1697        for task in self.iflat_tasks(status=self.S_ABICRITICAL, nids=nids):
1698            header = "=== " + task.qout_file.path + "==="
1699            app(header)
1700            report = task.get_event_report()
1702            if report is not None:
1703                app("num_errors: %s, num_warnings: %s, num_comments: %s" % (
1704                    report.num_errors, report.num_warnings, report.num_comments))
1705                app("*** ERRORS ***")
1706                app("\n".join(str(e) for e in report.errors))
1707                app("*** BUGS ***")
1708                app("\n".join(str(b) for b in report.bugs))
1710            else:
1711                app("get_envent_report returned None!")
1713            app("=" * len(header) + 2*"\n")
1715        return stream.writelines(lines)
1717    def show_qouts(self, nids=None, stream=sys.stdout):
1718        """
1719        Write to the given stream the content of the queue output file for all tasks whose status is S_QCRITICAL.
1721        Args:
1722            nids: optional list of node identifiers used to filter the tasks.
1723            stream: File-like object. Default: sys.stdout
1724        """
1725        lines = []
1727        for task in self.iflat_tasks(status=self.S_QCRITICAL, nids=nids):
1728            header = "=== " + task.qout_file.path + "==="
1729            lines.append(header)
1730            if task.qout_file.exists:
1731                with open(task.qout_file.path, "rt") as fh:
1732                    lines += fh.readlines()
1733            else:
1734                lines.append("File does not exist!")
1736            lines.append("=" * len(header) + 2*"\n")
1738        return stream.writelines(lines)
    def debug(self, status=None, nids=None, stream=sys.stdout):
        """
        This method is usually used when the flow didn't complete successfully.
        It analyzes the files produced by the tasks to facilitate debugging.
        Info are printed to `stream`.

        If the file `_exceptions` is found in the workdir of the flow, its content
        is printed and the method returns immediately without analyzing the tasks.

        Args:
            status: If not None, only the tasks with this status are selected.
                If None, the tasks whose status is S_ERROR, S_QCRITICAL or
                S_ABICRITICAL are analyzed (in this order).
            nids: optional list of node identifiers used to filter the tasks.
            stream: File-like object. Default: sys.stdout
        """
        # Only ncols is used below (width of the banners).
        nrows, ncols = get_terminal_size()

        # Test for scheduler exceptions first.
        sched_excfile = os.path.join(self.workdir, "_exceptions")
        if os.path.exists(sched_excfile):
            with open(sched_excfile, "r") as fh:
                cprint("Found exceptions raised by the scheduler", "red", file=stream)
                cprint(fh.read(), color="red", file=stream)
                return

        if status is not None:
            tasks = list(self.iflat_tasks(status=status, nids=nids))
        else:
            # Default selection: all the tasks in one of the three failure states.
            errors = list(self.iflat_tasks(status=self.S_ERROR, nids=nids))
            qcriticals = list(self.iflat_tasks(status=self.S_QCRITICAL, nids=nids))
            abicriticals = list(self.iflat_tasks(status=self.S_ABICRITICAL, nids=nids))
            tasks = errors + qcriticals + abicriticals

        # For each task selected:
        #     1) Check the error files of the task. If not empty, print the content to stdout and we are done.
        #     2) If error files are empty, look at the master log file for possible errors
        #     3) If also this check fails, scan all the process log files.
        #        TODO: This check is not needed if we introduce a new __abinit_error__ file
        #        that is created by the first MPI process that invokes MPI abort!
        #
        # NOTE(review): as implemented, step 1 does not stop the analysis: `count` is
        # initialized after the loop over the error files and is not incremented there,
        # so steps 2 and 3 are executed even if an error file has been printed.
        # Confirm whether this is intended.
        ntasks = 0
        for task in tasks:
            print(make_banner(str(task), width=ncols, mark="="), file=stream)
            ntasks += 1

            #  Start with error files.
            for efname in ["qerr_file", "stderr_file",]:
                err_file = getattr(task, efname)
                if err_file.exists:
                    s = err_file.read()
                    # Empty files are skipped.
                    if not s: continue
                    print(make_banner(str(err_file), width=ncols, mark="="), file=stream)
                    cprint(s, color="red", file=stream)
                    #count += 1

            # Check main log file.
            # `s` is set to the errors found in the report, to the message of the
            # exception raised while getting the report, or to None if nothing was found.
            try:
                report = task.get_event_report()
                if report and report.num_errors:
                    print(make_banner(os.path.basename(report.filename), width=ncols, mark="="), file=stream)
                    s = "\n".join(str(e) for e in report.errors)
                else:
                    s = None
            except Exception as exc:
                s = str(exc)

            count = 0 # count > 0 means we found some useful info that could explain the failures.
            if s is not None:
                cprint(s, color="red", file=stream)
                count += 1

            if not count:
                # Inspect all log files produced by the other nodes.
                log_files = task.tmpdir.list_filepaths(wildcard="*LOG_*")
                if not log_files:
                    cprint("No *LOG_* file in tmpdir. This usually happens if you are running with many CPUs",
                           color="magenta", file=stream)

                # Stop at the first log file that contains errors or cannot be parsed.
                for log_file in log_files:
                    try:
                        report = EventsParser().parse(log_file)
                        if report.errors:
                            print(report, file=stream)
                            count += 1
                            break
                    except Exception as exc:
                        cprint(str(exc), color="red", file=stream)
                        count += 1
                        break

            if not count:
                cprint("Houston, we could not find any error message that can explain the problem",
                        color="magenta", file=stream)

        print("Number of tasks analyzed: %d" % ntasks, file=stream)
1832    def cancel(self, nids=None):
1833        """
1834        Cancel all the tasks that are in the queue.
1835        nids is an optional list of node identifiers used to filter the tasks.
1837        Returns:
1838            Number of jobs cancelled, negative value if error
1839        """
1840        if self.has_chrooted:
1841            # TODO: Use paramiko to kill the job?
1842            warnings.warn("Cannot cancel the flow via sshfs!")
1843            return -1
1845        # If we are running with the scheduler, we must send a SIGKILL signal.
1846        if os.path.exists(self.pid_file):
1847            cprint("Found scheduler attached to this flow.", "yellow")
1848            cprint("Sending SIGKILL to the scheduler before cancelling the tasks!", "yellow")
1850            with open(self.pid_file, "rt") as fh:
1851                pid = int(fh.readline())
1853            retcode = os.system("kill -9 %d" % pid)
1854            self.history.info("Sent SIGKILL to the scheduler, retcode: %s" % retcode)
1855            try:
1856                os.remove(self.pid_file)
1857            except IOError:
1858                pass
1860        num_cancelled = 0
1861        for task in self.iflat_tasks(nids=nids):
1862            num_cancelled += task.cancel()
1864        return num_cancelled
1866    def get_njobs_in_queue(self, username=None):
1867        """
1868        returns the number of jobs in the queue, None when the number of jobs cannot be determined.
1870        Args:
1871            username: (str) the username of the jobs to count (default is to autodetect)
1872        """
1873        return self.manager.qadapter.get_njobs_in_queue(username=username)
1875    def rmtree(self, ignore_errors=False, onerror=None):
1876        """Remove workdir (same API as shutil.rmtree)."""
1877        if not os.path.exists(self.workdir): return
1878        shutil.rmtree(self.workdir, ignore_errors=ignore_errors, onerror=onerror)
1880    def rm_and_build(self):
1881        """Remove the workdir and rebuild the flow."""
1882        self.rmtree()
1883        self.build()
1885    def build(self, *args, **kwargs):
1886        """Make directories and files of the `Flow`."""
1887        # Allocate here if not done yet!
1888        if not self.allocated: self.allocate()
1890        self.indir.makedirs()
1891        self.outdir.makedirs()
1892        self.tmpdir.makedirs()
1894        # Check the nodeid file in workdir
1895        nodeid_path = os.path.join(self.workdir, ".nodeid")
1897        if os.path.exists(nodeid_path):
1898            with open(nodeid_path, "rt") as fh:
1899                node_id = int(fh.read())
1901            if self.node_id != node_id:
1902                msg = ("\nFound node_id %s in file:\n\n  %s\n\nwhile the node_id of the present flow is %d.\n"
1903                       "This means that you are trying to build a new flow in a directory already used by another flow.\n"
1904                       "Possible solutions:\n"
1905                       "   1) Change the workdir of the new flow.\n"
1906                       "   2) remove the old directory either with `rm -r` or by calling the method flow.rmtree()\n"
1907                       % (node_id, nodeid_path, self.node_id))
1908                raise RuntimeError(msg)
1910        else:
1911            with open(nodeid_path, "wt") as fh:
1912                fh.write(str(self.node_id))
1914        if self.pyfile and os.path.isfile(self.pyfile):
1915            shutil.copy(self.pyfile, self.workdir)
1917        for work in self:
1918            work.build(*args, **kwargs)
1920    def build_and_pickle_dump(self, abivalidate=False):
1921        """
1922        Build dirs and file of the `Flow` and save the object in pickle format.
1923        Returns 0 if success
1925        Args:
1926            abivalidate: If True, all the input files are validate by calling
1927                the abinit parser. If the validation fails, ValueError is raise.
1928        """
1929        self.build()
1930        if not abivalidate: return self.pickle_dump()
1932        # Validation with Abinit.
1933        isok, errors = self.abivalidate_inputs()
1934        if isok: return self.pickle_dump()
1935        errlines = []
1936        for i, e in enumerate(errors):
1937            errlines.append("[%d] %s" % (i, e))
1938        raise ValueError("\n".join(errlines))
1940    @check_spectator
1941    def pickle_dump(self):
1942        """
1943        Save the status of the object in pickle format.
1944        Returns 0 if success
1945        """
1946        if self.has_chrooted:
1947            warnings.warn("Cannot pickle_dump since we have chrooted from %s" % self.has_chrooted)
1948            return -1
1950        #if self.in_spectator_mode:
1951        #    warnings.warn("Cannot pickle_dump since flow is in_spectator_mode")
1952        #    return -2
1954        protocol = self.pickle_protocol
1956        # Atomic transaction with FileLock.
1957        with FileLock(self.pickle_file):
1958            with AtomicFile(self.pickle_file, mode="wb") as fh:
1959                pmg_pickle_dump(self, fh, protocol=protocol)
1961        return 0
1963    def pickle_dumps(self, protocol=None):
1964        """
1965        Return a string with the pickle representation.
1966        `protocol` selects the pickle protocol. self.pickle_protocol is used if `protocol` is None
1967        """
1968        strio = StringIO()
1969        pmg_pickle_dump(self, strio,
1970                        protocol=self.pickle_protocol if protocol is None
1971                        else protocol)
1972        return strio.getvalue()
1974    def register_task(self, input, deps=None, manager=None, task_class=None, append=False):
1975        """
1976        Utility function that generates a `Work` made of a single task
1978        Args:
1979            input: |AbinitInput|
1980            deps: List of :class:`Dependency` objects specifying the dependency of this node.
1981                  An empy list of deps implies that this node has no dependencies.
1982            manager: The |TaskManager| responsible for the submission of the task.
1983                     If manager is None, we use the |TaskManager| specified during the creation of the work.
1984            task_class: Task subclass to instantiate. Default: |AbinitTask|
1985            append: If true, the task is added to the last work (a new Work is created if flow is empty)
1987        Returns:
1988            The generated |Work| for the task, work[0] is the actual task.
1989        """
1990        # append True is much easier to use. In principle should be the default behaviour
1991        # but this would break the previous API so ...
1992        if not append:
1993            work = Work(manager=manager)
1994        else:
1995            if not self.works:
1996                work = Work(manager=manager)
1997                append = False
1998            else:
1999                work = self.works[-1]
2001        task = work.register(input, deps=deps, task_class=task_class)
2002        if not append: self.register_work(work)
2004        return work
2006    def register_work(self, work, deps=None, manager=None, workdir=None):
2007        """
2008        Register a new |Work| and add it to the internal list, taking into account possible dependencies.
2010        Args:
2011            work: |Work| object.
2012            deps: List of :class:`Dependency` objects specifying the dependency of this node.
2013                  An empy list of deps implies that this node has no dependencies.
2014            manager: The |TaskManager| responsible for the submission of the task.
2015                     If manager is None, we use the `TaskManager` specified during the creation of the work.
2016            workdir: The name of the directory used for the |Work|.
2018        Returns:
2019            The registered |Work|.
2020        """
2021        if getattr(self, "workdir", None) is not None:
2022            # The flow has a directory, build the name of the directory of the work.
2023            work_workdir = None
2024            if workdir is None:
2025                work_workdir = os.path.join(self.workdir, "w" + str(len(self)))
2026            else:
2027                work_workdir = os.path.join(self.workdir, os.path.basename(workdir))
2029            work.set_workdir(work_workdir)
2031        if manager is not None:
2032            work.set_manager(manager)
2034        self.works.append(work)
2036        if deps:
2037            deps = [Dependency(node, exts) for node, exts in deps.items()]
2038            work.add_deps(deps)
2040        return work
2042    def register_work_from_cbk(self, cbk_name, cbk_data, deps, work_class, manager=None):
2043        """
2044        Registers a callback function that will generate the :class:`Task` of the :class:`Work`.
2046        Args:
2047            cbk_name: Name of the callback function (must be a bound method of self)
2048            cbk_data: Additional data passed to the callback function.
2049            deps: List of :class:`Dependency` objects specifying the dependency of the work.
2050            work_class: |Work| class to instantiate.
2051            manager: The |TaskManager| responsible for the submission of the task.
2052                    If manager is None, we use the `TaskManager` specified during the creation of the |Flow|.
2054        Returns:
2055            The |Work| that will be finalized by the callback.
2056        """
2057        # TODO: pass a Work factory instead of a class
2058        # Directory of the Work.
2059        work_workdir = os.path.join(self.workdir, "w" + str(len(self)))
2061        # Create an empty work and register the callback
2062        work = work_class(workdir=work_workdir, manager=manager)
2064        self._works.append(work)
2066        deps = [Dependency(node, exts) for node, exts in deps.items()]
2067        if not deps:
2068            raise ValueError("A callback must have deps!")
2070        work.add_deps(deps)
2072        # Wrap the callable in a Callback object and save
2073        # useful info such as the index of the work and the callback data.
2074        cbk = FlowCallback(cbk_name, self, deps=deps, cbk_data=cbk_data)
2075        self._callbacks.append(cbk)
2077        return work
2079    @property
2080    def allocated(self):
2081        """Numer of allocations. Set by `allocate`."""
2082        try:
2083            return self._allocated
2084        except AttributeError:
2085            return 0
2087    def allocate(self, workdir=None, use_smartio=False):
2088        """
2089        Allocate the `Flow` i.e. assign the `workdir` and (optionally)
2090        the |TaskManager| to the different tasks in the Flow.
2092        Args:
2093            workdir: Working directory of the flow. Must be specified here
2094                if we haven't initialized the workdir in the __init__.
2096        Return:
2097            self
2098        """
2099        if workdir is not None:
2100            # We set the workdir of the flow here
2101            self.set_workdir(workdir)
2102            for i, work in enumerate(self):
2103                work.set_workdir(os.path.join(self.workdir, "w" + str(i)))
2105        if not hasattr(self, "workdir"):
2106            raise RuntimeError("You must call flow.allocate(workdir) if the workdir is not passed to __init__")
2108        for work in self:
2109            # Each work has a reference to its flow.
2110            work.allocate(manager=self.manager)
2111            work.set_flow(self)
2112            # Each task has a reference to its work.
2113            for task in work:
2114                task.set_work(work)
2116        self.check_dependencies()
2118        if not hasattr(self, "_allocated"): self._allocated = 0
2119        self._allocated += 1
2121        if use_smartio:
2122            self.use_smartio()
2124        return self
2126    def use_smartio(self):
2127        """
2128        This function should be called when the entire `Flow` has been built.
2129        It tries to reduce the pressure on the hard disk by using Abinit smart-io
2130        capabilities for those files that are not needed by other nodes.
2131        Smart-io means that big files (e.g. WFK) are written only if the calculation
2132        is unconverged so that we can restart from it. No output is produced if
2133        convergence is achieved.
2135        Return: self
2136        """
2137        if not self.allocated:
2138            #raise RuntimeError("You must call flow.allocate() before invoking flow.use_smartio()")
2139            self.allocate()
2141        for task in self.iflat_tasks():
2142            children = task.get_children()
2143            if not children:
2144                # Change the input so that output files are produced
2145                # only if the calculation is not converged.
2146                task.history.info("Will disable IO for task")
2147                task.set_vars(prtwf=-1, prtden=0) # TODO: prt1wf=-1,
2148            else:
2149                must_produce_abiexts = []
2150                for child in children:
2151                    # Get the list of dependencies. Find that task
2152                    for d in child.deps:
2153                        must_produce_abiexts.extend(d.exts)
2155                must_produce_abiexts = set(must_produce_abiexts)
2156                #print("must_produce_abiexts", must_produce_abiexts)
2158                # Variables supporting smart-io.
2159                smart_prtvars = {
2160                    "prtwf": "WFK",
2161                }
2163                # Set the variable to -1 to disable the output
2164                for varname, abiext in smart_prtvars.items():
2165                    if abiext not in must_produce_abiexts:
2166                        print("%s: setting %s to -1" % (task, varname))
2167                        task.set_vars({varname: -1})
2169        return self
2171    def show_dependencies(self, stream=sys.stdout):
2172        """
2173        Writes to the given stream the ASCII representation of the dependency tree.
2174        """
2175        def child_iter(node):
2176            return [d.node for d in node.deps]
2178        def text_str(node):
2179            return colored(str(node), color=node.status.color_opts["color"])
2181        for task in self.iflat_tasks():
2182            print(draw_tree(task, child_iter, text_str), file=stream)
2184    def on_all_ok(self):
2185        """
2186        This method is called when all the works in the flow have reached S_OK.
2187        This method shall return True if the calculation is completed or
2188        False if the execution should continue due to side-effects such as adding a new work to the flow.
2190        This methods allows subclasses to implement customized logic such as extending the flow by adding new works.
2191        The flow has an internal counter: on_all_ok_num_calls
2192        that shall be incremented by client code when subclassing this method.
2193        This counter can be used to decide if futher actions are needed or not.
2195        An example of flow that adds a new work (only once) when all_ok is reached for the first time:
2197        def on_all_ok(self):
2198            if self.on_all_ok_num_calls > 0: return True
2199            self.on_all_ok_num_calls += 1
2201            `implement_logic_to_create_new_work`
2203            self.register_work(work)
2204            self.allocate()
2205            self.build_and_pickle_dump()
2207            return False # The scheduler will keep on running the flow.
2208        """
2209        return True
2211    def on_dep_ok(self, signal, sender):
2212        # TODO
2213        # Replace this callback with dynamic dispatch
2214        # on_all_S_OK for work
2215        # on_S_OK for task
2216        self.history.info("on_dep_ok with sender %s, signal %s" % (str(sender), signal))
2218        for i, cbk in enumerate(self._callbacks):
2219            if not cbk.handle_sender(sender):
2220                self.history.info("%s does not handle sender %s" % (cbk, sender))
2221                continue
2223            if not cbk.can_execute():
2224                self.history.info("Cannot execute %s" % cbk)
2225                continue
2227            # Execute the callback and disable it
2228            self.history.info("flow in on_dep_ok: about to execute callback %s" % str(cbk))
2229            cbk()
2230            cbk.disable()
2232            # Update the database.
2233            self.pickle_dump()
    @check_spectator
    def finalize(self):
        """
        This method is called when the flow is completed.

        Returns:
            0 if success, 1 if the flow was already finalized,
            2 if the insertion of the results in the database failed.
        """
        self.history.info("Calling flow.finalize.")
        if self.finalized:
            self.history.warning("Calling finalize on an already finalized flow. Returning 1")
            return 1

        # The flag is set before the database insertion,
        # hence it is still True when we return 2 below.
        self.finalized = True

        if self.has_db:
            self.history.info("Saving results in database.")
            try:
                # NOTE(review): db_insert is reached via `self.flow` and not via `self`;
                # confirm that a flow defines this attribute. If it does not, the
                # AttributeError is swallowed by the except clause and 2 is returned.
                self.flow.db_insert()
                # Redundant: the flag has been already set above.
                self.finalized = True
            except Exception:
                self.history.critical("MongoDb insertion failed.")
                return 2

        # Here we remove the big output files if we have the garbage collector
        # and the policy is set to "flow."
        if self.gc is not None and self.gc.policy == "flow":
            self.history.info("gc.policy set to flow. Will clean task output files.")
            for task in self.iflat_tasks():
                task.clean_output_files()

        return 0
2265    def set_garbage_collector(self, exts=None, policy="task"):
2266        """
2267        Enable the garbage collector that will remove the big output files that are not needed.
2269        Args:
2270            exts: string or list with the Abinit file extensions to be removed. A default is
2271                provided if exts is None
2272            policy: Either `flow` or `task`. If policy is set to 'task', we remove the output
2273                files as soon as the task reaches S_OK. If 'flow', the files are removed
2274                only when the flow is finalized. This option should be used when we are dealing
2275                with a dynamic flow with callbacks generating other tasks since a |Task|
2276                might not be aware of its children when it reached S_OK.
2277        """
2278        assert policy in ("task", "flow")
2279        exts = list_strings(exts) if exts is not None else ("WFK", "SUS", "SCR", "BSR", "BSC")
2281        gc = GarbageCollector(exts=set(exts), policy=policy)
2283        self.set_gc(gc)
2284        for work in self:
2285            #work.set_gc(gc) # TODO Add support for Works and flow policy
2286            for task in work:
2287                task.set_gc(gc)
2289    def connect_signals(self):
2290        """
2291        Connect the signals within the `Flow`.
2292        The `Flow` is responsible for catching the important signals raised from its works.
2293        """
2294        # Connect the signals inside each Work.
2295        for work in self:
2296            work.connect_signals()
2298        # Observe the nodes that must reach S_OK in order to call the callbacks.
2299        for cbk in self._callbacks:
2300            #cbk.enable()
2301            for dep in cbk.deps:
2302                self.history.info("Connecting %s \nwith sender %s, signal %s" % (str(cbk), dep.node, dep.node.S_OK))
2303                dispatcher.connect(self.on_dep_ok, signal=dep.node.S_OK, sender=dep.node, weak=False)
2305        # Associate to each signal the callback _on_signal
2306        # (bound method of the node that will be called by `Flow`
2307        # Each node will set its attribute _done_signal to True to tell
2308        # the flow that this callback should be disabled.
2310        # Register the callbacks for the Work.
2311        #for work in self:
2312        #    slot = self._sig_slots[work]
2313        #    for signal in S_ALL:
2314        #        done_signal = getattr(work, "_done_ " + signal, False)
2315        #        if not done_sig:
2316        #            cbk_name = "_on_" + str(signal)
2317        #            cbk = getattr(work, cbk_name, None)
2318        #            if cbk is None: continue
2319        #            slot[work][signal].append(cbk)
2320        #            print("connecting %s\nwith sender %s, signal %s" % (str(cbk), dep.node, dep.node.S_OK))
2321        #            dispatcher.connect(self.on_dep_ok, signal=signal, sender=dep.node, weak=False)
2323        # Register the callbacks for the Tasks.
2324        #self.show_receivers()
2326    def disconnect_signals(self):
2327        """Disable the signals within the `Flow`."""
2328        # Disconnect the signals inside each Work.
2329        for work in self:
2330            work.disconnect_signals()
2332        # Disable callbacks.
2333        for cbk in self._callbacks:
2334            cbk.disable()
2336    def show_receivers(self, sender=None, signal=None):
2337        sender = sender if sender is not None else dispatcher.Any
2338        signal = signal if signal is not None else dispatcher.Any
2339        print("*** live receivers ***")
2340        for rec in dispatcher.liveReceivers(dispatcher.getReceivers(sender, signal)):
2341            print("receiver -->", rec)
2342        print("*** end live receivers ***")
2344    def set_spectator_mode(self, mode=True):
2345        """
2346        When the flow is in spectator_mode, we have to disable signals, pickle dump and possible callbacks
2347        A spectator can still operate on the flow but the new status of the flow won't be saved in
2348        the pickle file. Usually the flow is in spectator mode when we are already running it via
2349        the scheduler or other means and we should not interfere with its evolution.
2350        This is the reason why signals and callbacks must be disabled.
2351        Unfortunately preventing client-code from calling methods with side-effects when
2352        the flow is in spectator mode is not easy (e.g. flow.cancel will cancel the tasks submitted to the
2353        queue and the flow used by the scheduler won't see this change!
2354        """
2355        # Set the flags of all the nodes in the flow.
2356        mode = bool(mode)
2357        self.in_spectator_mode = mode
2358        for node in self.iflat_nodes():
2359            node.in_spectator_mode = mode
2361        # connect/disconnect signals depending on mode.
2362        if not mode:
2363            self.connect_signals()
2364        else:
2365            self.disconnect_signals()
2367    #def get_results(self, **kwargs)
2369    def rapidfire(self, check_status=True, max_nlaunch=-1, max_loops=1, sleep_time=5, **kwargs):
2370        """
2371        Use :class:`PyLauncher` to submits tasks in rapidfire mode.
2372        kwargs contains the options passed to the launcher.
2374        Args:
2375            check_status:
2376            max_nlaunch: Maximum number of launches. default: no limit.
2377            max_loops: Maximum number of loops
2378            sleep_time: seconds to sleep between rapidfire loop iterations
2380        Return: Number of tasks submitted.
2381        """
2382        self.check_pid_file()
2383        self.set_spectator_mode(False)
2384        if check_status: self.check_status()
2385        from .launcher import PyLauncher
2386        return PyLauncher(self, **kwargs).rapidfire(max_nlaunch=max_nlaunch, max_loops=max_loops, sleep_time=sleep_time)
2388    def single_shot(self, check_status=True, **kwargs):
2389        """
2390        Use :class:`PyLauncher` to submits one task.
2391        kwargs contains the options passed to the launcher.
2393        Return: Number of tasks submitted.
2394        """
2395        self.check_pid_file()
2396        self.set_spectator_mode(False)
2397        if check_status: self.check_status()
2398        from .launcher import PyLauncher
2399        return PyLauncher(self, **kwargs).single_shot()
2401    def make_scheduler(self, **kwargs):
2402        """
2403        Build and return a :class:`PyFlowScheduler` to run the flow.
2405        Args:
2406            kwargs: if empty we use the user configuration file.
2407                    if `filepath` in kwargs we init the scheduler from filepath.
2408                    else pass kwargs to :class:`PyFlowScheduler` __init__ method.
2409        """
2410        from .launcher import PyFlowScheduler
2411        if not kwargs:
2412            # User config if kwargs is empty
2413            sched = PyFlowScheduler.from_user_config()
2414        else:
2415            # Use from_file if filepath if present, else call __init__
2416            filepath = kwargs.pop("filepath", None)
2417            if filepath is not None:
2418                assert not kwargs
2419                sched = PyFlowScheduler.from_file(filepath)
2420            else:
2421                sched = PyFlowScheduler(**kwargs)
2423        sched.add_flow(self)
2424        return sched
2426    def batch(self, timelimit=None):
2427        """
2428        Run the flow in batch mode, return exit status of the job script.
2429        Requires a manager.yml file and a batch_adapter adapter.
2431        Args:
2432            timelimit: Time limit (int with seconds or string with time given with the slurm convention:
2433            "days-hours:minutes:seconds"). If timelimit is None, the default value specified in the
2434            `batch_adapter` entry of `manager.yml` is used.
2435        """
2436        from .launcher import BatchLauncher
2437        # Create a batch dir from the flow.workdir.
2438        prev_dir = os.path.join(*self.workdir.split(os.path.sep)[:-1])
2439        prev_dir = os.path.join(os.path.sep, prev_dir)
2440        workdir = os.path.join(prev_dir, os.path.basename(self.workdir) + "_batch")
2442        return BatchLauncher(workdir=workdir, flows=self).submit(timelimit=timelimit)
2444    def make_light_tarfile(self, name=None):
2445        """Lightweight tarball file. Mainly used for debugging. Return the name of the tarball file."""
2446        name = os.path.basename(self.workdir) + "-light.tar.gz" if name is None else name
2447        return self.make_tarfile(name=name, exclude_dirs=["outdata", "indata", "tmpdata"])
2449    def make_tarfile(self, name=None, max_filesize=None, exclude_exts=None, exclude_dirs=None, verbose=0, **kwargs):
2450        """
2451        Create a tarball file.
2453        Args:
2454            name: Name of the tarball file. Set to os.path.basename(`flow.workdir`) + "tar.gz"` if name is None.
2455            max_filesize (int or string with unit): a file is included in the tar file if its size <= max_filesize
2456                Can be specified in bytes e.g. `max_files=1024` or with a string with unit e.g. `max_filesize="1 Mb"`.
2457                No check is done if max_filesize is None.
2458            exclude_exts: List of file extensions to be excluded from the tar file.
2459            exclude_dirs: List of directory basenames to be excluded.
2460            verbose (int): Verbosity level.
2461            kwargs: keyword arguments passed to the :class:`TarFile` constructor.
2463        Returns: The name of the tarfile.
2464        """
2465        def any2bytes(s):
2466            """Convert string or number to memory in bytes."""
2467            if is_string(s):
2468                return int(Memory.from_string(s).to("b"))
2469            else:
2470                return int(s)
2472        if max_filesize is not None:
2473            max_filesize = any2bytes(max_filesize)
2475        if exclude_exts:
2476            # Add/remove ".nc" so that we can simply pass "GSR" instead of "GSR.nc"
2477            # Moreover this trick allows one to treat WFK.nc and WFK file on the same footing.
2478            exts = []
2479            for e in list_strings(exclude_exts):
2480                exts.append(e)
2481                if e.endswith(".nc"):
2482                    exts.append(e.replace(".nc", ""))
2483                else:
2484                    exts.append(e + ".nc")
2485            exclude_exts = exts
2487        def filter(tarinfo):
2488            """
2489            Function that takes a TarInfo object argument and returns the changed TarInfo object.
2490            If it instead returns None the TarInfo object will be excluded from the archive.
2491            """
2492            # Skip links.
2493            if tarinfo.issym() or tarinfo.islnk():
2494                if verbose: print("Excluding link: %s" % tarinfo.name)
2495                return None
2497            # Check size in bytes
2498            if max_filesize is not None and tarinfo.size > max_filesize:
2499                if verbose: print("Excluding %s due to max_filesize" % tarinfo.name)
2500                return None
2502            # Filter filenames.
2503            if exclude_exts and any(tarinfo.name.endswith(ext) for ext in exclude_exts):
2504                if verbose: print("Excluding %s due to extension" % tarinfo.name)
2505                return None
2507            # Exlude directories (use dir basenames).
2508            if exclude_dirs and any(dir_name in exclude_dirs for dir_name in tarinfo.name.split(os.path.sep)):
2509                if verbose: print("Excluding %s due to exclude_dirs" % tarinfo.name)
2510                return None
2512            return tarinfo
2514        back = os.getcwd()
2515        os.chdir(os.path.join(self.workdir, ".."))
2517        import tarfile
2518        name = os.path.basename(self.workdir) + ".tar.gz" if name is None else name
2519        with tarfile.open(name=name, mode='w:gz', **kwargs) as tar:
2520            tar.add(os.path.basename(self.workdir), arcname=None, recursive=True, filter=filter)
2522            # Add the script used to generate the flow.
2523            if self.pyfile is not None and os.path.exists(self.pyfile):
2524                tar.add(self.pyfile)
2526        os.chdir(back)
2527        return name
2529    def get_graphviz(self, engine="automatic", graph_attr=None, node_attr=None, edge_attr=None):
2530        """
2531        Generate flow graph in the DOT language.
2533        Args:
2534            engine: Layout command used. ['dot', 'neato', 'twopi', 'circo', 'fdp', 'sfdp', 'patchwork', 'osage']
2535            graph_attr: Mapping of (attribute, value) pairs for the graph.
2536            node_attr: Mapping of (attribute, value) pairs set for all nodes.
2537            edge_attr: Mapping of (attribute, value) pairs set for all edges.
2539        Returns: graphviz.Digraph <https://graphviz.readthedocs.io/en/stable/api.html#digraph>
2540        """
2541        self.allocate()
2543        from graphviz import Digraph
2544        fg = Digraph("flow", #filename="flow_%s.gv" % os.path.basename(self.relworkdir),
2545            engine="fdp" if engine == "automatic" else engine)
2547        # Set graph attributes. https://www.graphviz.org/doc/info/
2548        fg.attr(label=repr(self))
2549        fg.attr(rankdir="LR", pagedir="BL")
2550        fg.node_attr.update(color='lightblue2', style='filled')
2552        # Add input attributes.
2553        if graph_attr is not None: fg.graph_attr.update(**graph_attr)
2554        if node_attr is not None: fg.node_attr.update(**node_attr)
2555        if edge_attr is not None: fg.edge_attr.update(**edge_attr)
2557        def node_kwargs(node):
2558            return dict(
2559                #shape="circle",
2560                color=node.color_hex,
2561                fontsize="8.0",
2562                label=(str(node) if not hasattr(node, "pos_str") else
2563                    node.pos_str + "\n" + node.__class__.__name__),
2564            )
2566        edge_kwargs = dict(arrowType="vee", style="solid")
2567        cluster_kwargs = dict(rankdir="LR", pagedir="BL", style="rounded", bgcolor="azure2")
2569        for work in self:
2570            # Build cluster with tasks.
2571            cluster_name = "cluster%s" % work.name
2572            with fg.subgraph(name=cluster_name) as wg:
2573                wg.attr(**cluster_kwargs)
2574                wg.attr(label="%s (%s)" % (work.__class__.__name__, work.name))
2575                for task in work:
2576                    wg.node(task.name, **node_kwargs(task))
2577                    # Connect children to task.
2578                    for child in task.get_children():
2579                        # Find file extensions required by this task
2580                        i = [dep.node for dep in child.deps].index(task)
2581                        edge_label = "+".join(child.deps[i].exts)
2582                        fg.edge(task.name, child.name, label=edge_label, color=task.color_hex,
2583                                **edge_kwargs)
2585        # Treat the case in which we have a work producing output for other tasks.
2586        for work in self:
2587            children = work.get_children()
2588            if not children: continue
2589            cluster_name = "cluster%s" % work.name
2590            seen = set()
2591            for child in children:
2592                # This is not needed, too much confusing
2593                #fg.edge(cluster_name, child.name, color=work.color_hex, **edge_kwargs)
2594                # Find file extensions required by work
2595                i = [dep.node for dep in child.deps].index(work)
2596                for ext in child.deps[i].exts:
2597                    out = "%s (%s)" % (ext, work.name)
2598                    fg.node(out)
2599                    fg.edge(out, child.name, **edge_kwargs)
2600                    key = (cluster_name, out)
2601                    if key not in seen:
2602                        seen.add(key)
2603                        fg.edge(cluster_name, out, color=work.color_hex, **edge_kwargs)
2605        # Treat the case in which we have a task that depends on external files.
2606        seen = set()
2607        for task in self.iflat_tasks():
2608            #print(task.get_parents())
2609            for node in (p for p in task.get_parents() if p.is_file):
2610                #print("parent file node", node)
2611                #infile = "%s (%s)" % (ext, work.name)
2612                infile = node.filepath
2613                if infile not in seen:
2614                    seen.add(infile)
2615                    fg.node(infile, **node_kwargs(node))
2617                fg.edge(infile, task.name, color=node.color_hex, **edge_kwargs)
2619        return fg
2621    @add_fig_kwargs
2622    def graphviz_imshow(self, ax=None, figsize=None, dpi=300, fmt="png", **kwargs):
2623        """
2624        Generate flow graph in the DOT language and plot it with matplotlib.
2626        Args:
2627            ax: |matplotlib-Axes| or None if a new figure should be created.
2628            figsize: matplotlib figure size (None to use default)
2629            dpi: DPI value.
2630            fmt: Select format for output image
2632        Return: |matplotlib-Figure|
2633        """
2634        graph = self.get_graphviz(**kwargs)
2635        graph.format = fmt
2636        graph.attr(dpi=str(dpi))
2637        _, tmpname = tempfile.mkstemp()
2638        path = graph.render(tmpname, view=False, cleanup=True)
2639        ax, fig, _ = get_ax_fig_plt(ax=ax, figsize=figsize, dpi=dpi)
2640        import matplotlib.image as mpimg
2641        ax.imshow(mpimg.imread(path, format="png")) #, interpolation="none")
2642        ax.axis("off")
2644        return fig
2646    @add_fig_kwargs
2647    def plot_networkx(self, mode="network", with_edge_labels=False, ax=None, arrows=False,
2648                      node_size="num_cores", node_label="name_class", layout_type="spring", **kwargs):
2649        """
2650        Use networkx to draw the flow with the connections among the nodes and
2651        the status of the tasks.
2653        Args:
2654            mode: `networkx` to show connections, `status` to group tasks by status.
2655            with_edge_labels: True to draw edge labels.
2656            ax: |matplotlib-Axes| or None if a new figure should be created.
2657            arrows: if True draw arrowheads.
2658            node_size: By default, the size of the node is proportional to the number of cores used.
2659            node_label: By default, the task class is used to label node.
2660            layout_type: Get positions for all nodes using `layout_type`. e.g. pos = nx.spring_layout(g)
2662        .. warning::
2664            Requires networkx package.
2665        """
2666        self.allocate()
2668        # Build the graph
2669        import networkx as nx
2670        g = nx.Graph() if not arrows else nx.DiGraph()
2671        edge_labels = {}
2672        for task in self.iflat_tasks():
2673            g.add_node(task, name=task.name)
2674            for child in task.get_children():
2675                g.add_edge(task, child)
2676                # TODO: Add getters! What about locked nodes!
2677                i = [dep.node for dep in child.deps].index(task)
2678                edge_labels[(task, child)] = " ".join(child.deps[i].exts)
2680            filedeps = [d for d in task.deps if d.node.is_file]
2681            for d in filedeps:
2682                #print(d.node, d.exts)
2683                g.add_node(d.node, name="%s (%s)" % (d.node.basename, d.node.node_id))
2684                g.add_edge(d.node, task)
2685                edge_labels[(d.node, task)] = "+".join(d.exts)
2687        # This part is needed if we have a work that produces output used by other nodes.
2688        for work in self:
2689            children = work.get_children()
2690            if not children:
2691                continue
2692            g.add_node(work, name=work.name)
2693            for task in work:
2694                g.add_edge(task, work)
2695                edge_labels[(task, work)] = "all_ok "
2696            for child in children:
2697                g.add_edge(work, child)
2698                i = [dep.node for dep in child.deps].index(work)
2699                edge_labels[(work, child)] = "+".join(child.deps[i].exts)
2701        # Get positions for all nodes using layout_type.
2702        # e.g. pos = nx.spring_layout(g)
2703        pos = getattr(nx, layout_type + "_layout")(g)
2705        # Select function used to compute the size of the node
2706        def make_node_size(node):
2707            if node.is_task:
2708                return 300 * node.manager.num_cores
2709            else:
2710                return 600
2712        # Function used to build the label
2713        def make_node_label(node):
2714            if node_label == "name_class":
2715                if node.is_file:
2716                    return "%s\n(%s)" % (node.basename, node.node_id)
2717                else:
2718                    return (node.pos_str + "\n" + node.__class__.__name__
2719                            if hasattr(node, "pos_str") else str(node))
2720            else:
2721                raise NotImplementedError("node_label: %s" % str(node_label))
2723        labels = {node: make_node_label(node) for node in g.nodes()}
2724        ax, fig, plt = get_ax_fig_plt(ax=ax)
2726        # Select plot type.
2727        if mode == "network":
2728            nx.draw_networkx(g, pos, labels=labels,
2729                             node_color=[node.color_rgb for node in g.nodes()],
2730                             node_size=[make_node_size(node) for node in g.nodes()],
2731                             width=1, style="dotted", with_labels=True, arrows=arrows, ax=ax)
2733            # Draw edge labels
2734            if with_edge_labels:
2735                nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, ax=ax)
2737        elif mode == "status":
2738            # Group tasks by status (only tasks are show here).
2739            for status in self.ALL_STATUS:
2740                tasks = list(self.iflat_tasks(status=status))
2742                # Draw nodes (color is given by status)
2743                node_color = status.color_opts["color"]
2744                if node_color is None: node_color = "black"
2745                #print("num nodes %s with node_color %s" % (len(tasks), node_color))
2747                nx.draw_networkx_nodes(g, pos,
2748                                       nodelist=tasks,
2749                                       node_color=node_color,
2750                                       node_size=[make_node_size(task) for task in tasks],
2751                                       alpha=0.5, ax=ax
2752                                       #label=str(status),
2753                                       )
2754            # Draw edges.
2755            nx.draw_networkx_edges(g, pos, width=2.0, alpha=0.5, arrows=arrows, ax=ax) # edge_color='r')
2757            # Draw labels
2758            nx.draw_networkx_labels(g, pos, labels, font_size=12, ax=ax)
2760            # Draw edge labels
2761            if with_edge_labels:
2762                nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, ax=ax)
2763                #label_pos=0.5, font_size=10, font_color='k', font_family='sans-serif', font_weight='normal',
2764                # alpha=1.0, bbox=None, ax=None, rotate=True, **kwds)
2766        else:
2767            raise ValueError("Unknown value for mode: %s" % str(mode))
2769        ax.axis("off")
2770        return fig
2772    def write_open_notebook(flow, foreground):
2773        """
2774        Generate an ipython notebook and open it in the browser.
2775        Return system exit code.
2776        """
2777        import nbformat
2778        nbf = nbformat.v4
2779        nb = nbf.new_notebook()
2781        nb.cells.extend([
2782            #nbf.new_markdown_cell("This is an auto-generated notebook for %s" % os.path.basename(pseudopath)),
2783            nbf.new_code_cell("""\
2784    import sys, os
2785    import numpy as np
2787    %matplotlib notebook
2788    from IPython.display import display
2790    # This to render pandas DataFrames with https://github.com/quantopian/qgrid
2791    #import qgrid
2792    #qgrid.nbinstall(overwrite=True)  # copies javascript dependencies to your /nbextensions folder
2794    from abipy import abilab
2796    # Tell AbiPy we are inside a notebook and use seaborn settings for plots.
2797    # See https://seaborn.pydata.org/generated/seaborn.set.html#seaborn.set
2798    abilab.enable_notebook(with_seaborn=True)
2799    """),
2801            nbf.new_code_cell("flow = abilab.Flow.pickle_load('%s')" % flow.workdir),
2802            nbf.new_code_cell("if flow.num_errored_tasks: flow.debug()"),
2803            nbf.new_code_cell("flow.check_status(show=True, verbose=0)"),
2804            nbf.new_code_cell("flow.show_dependencies()"),
2805            nbf.new_code_cell("flow.plot_networkx();"),
2806            nbf.new_code_cell("#flow.get_graphviz();"),
2807            nbf.new_code_cell("flow.show_inputs(nids=None, wslice=None)"),
2808            nbf.new_code_cell("flow.show_history()"),
2809            nbf.new_code_cell("flow.show_corrections()"),
2810            nbf.new_code_cell("flow.show_event_handlers()"),
2811            nbf.new_code_cell("flow.inspect(nids=None, wslice=None)"),
2812            nbf.new_code_cell("flow.show_abierrors()"),
2813            nbf.new_code_cell("flow.show_qouts()"),
2814        ])
2816        import tempfile, io
2817        _, nbpath = tempfile.mkstemp(suffix='.ipynb', text=True)
2819        with io.open(nbpath, 'wt', encoding="utf8") as fh:
2820            nbformat.write(nb, fh)
2822        from monty.os.path import which
2823        has_jupyterlab = which("jupyter-lab") is not None
2824        appname = "jupyter-lab" if has_jupyterlab else "jupyter notebook"
2825        if not has_jupyterlab:
2826            if which("jupyter") is None:
2827                raise RuntimeError("Cannot find jupyter in $PATH. Install it with `pip install`")
2829        appname = "jupyter-lab" if has_jupyterlab else "jupyter notebook"
2831        if foreground:
2832            return os.system("%s %s" % (appname, nbpath))
2833        else:
2834            fd, tmpname = tempfile.mkstemp(text=True)
2835            print(tmpname)
2836            cmd = "%s %s" % (appname, nbpath)
2837            print("Executing:", cmd, "\nstdout and stderr redirected to %s" % tmpname)
2838            import subprocess
2839            process = subprocess.Popen(cmd.split(), shell=False, stdout=fd, stderr=fd)
2840            cprint("pid: %s" % str(process.pid), "yellow")
2841            return process.returncode
class G0W0WithQptdmFlow(Flow):
    """
    :class:`Flow` for one-shot G0W0 calculations in which the q-points
    of the screening are computed by independent tasks (qptdm).
    """

    def __init__(self, workdir, scf_input, nscf_input, scr_input, sigma_inputs, manager=None):
        """
        Build the flow. The q-points for the screening are parallelized with qptdm
        i.e. we run independent calculations for each q-point and then we merge the final results.
        The flow is allocated at the end of the constructor.

        Args:
            workdir: Working directory.
            scf_input: Input for the GS SCF run.
            nscf_input: Input for the NSCF run (band structure run).
            scr_input: Input for the SCR run.
            sigma_inputs: Input(s) for the SIGMA run(s). A single input is also accepted.
            manager: |TaskManager| object used to submit the jobs. Initialized from manager.yml if manager is None.
        """
        super().__init__(workdir, manager=manager)

        # First work: GS + NSCF calculation.
        bands_work = self.register_work(BandStructureWork(scf_input, nscf_input))
        nscf_task = bands_work.nscf_task

        # Second work: SCR with qptdm. Its tasks are created by the callback cbk_qptdm_workflow.
        scr_work = self.register_work_from_cbk(cbk_name="cbk_qptdm_workflow", cbk_data={"input": scr_input},
                                               deps={nscf_task: "WFK"}, work_class=QptdmWork)

        # Third work: SIGMA tasks using the data produced by the previous two works.
        if not isinstance(sigma_inputs, (list, tuple)):
            sigma_inputs = [sigma_inputs]

        sigma_work = Work()
        for inp in sigma_inputs:
            sigma_work.register_sigma_task(inp, deps={nscf_task: "WFK", scr_work: "SCR"})
        self.register_work(sigma_work)

        self.allocate()

    def cbk_qptdm_workflow(self, cbk):
        """
        This callback is executed by the flow when bands_work.nscf_task reaches S_OK.

        It computes the list of q-points for the W(q,G,G'), creates nqpt tasks
        in the second work (QptdmWork), and connect the signals.

        Args:
            cbk: Callback object. The input for the SCR run is read from cbk.data["input"].

        Return: the second work of the flow, after the call to its `build` method.
        """
        scr_input = cbk.data["input"]

        # WFK file produced by the second Task of the first Work (NSCF step).
        wfk_file = self[0][1].outdir.has_abiext("WFK")

        qptdm_work = self[1]
        qptdm_work.set_manager(self.manager)
        qptdm_work.create_tasks(wfk_file, scr_input)
        qptdm_work.add_deps(cbk.deps)
        qptdm_work.set_flow(self)

        for task in qptdm_work:
            # Each task has a reference to its work.
            task.set_work(qptdm_work)
            # Add the garbage collector.
            if self.gc is not None:
                task.set_gc(self.gc)

        qptdm_work.connect_signals()
        qptdm_work.build()

        return qptdm_work
class FlowCallbackError(Exception):
    """Error raised by :class:`FlowCallback`."""


class FlowCallback:
    """
    This object implements the callbacks executed by the |Flow| when
    particular conditions are fulfilled. See on_dep_ok method of |Flow|.

    .. note::

        Callbacks are implemented via this object instead of a standard
        approach based on bound methods because:

            1) pickle (v<=3) does not support the pickling/unpickling of bound methods

            2) There's some extra logic and extra data needed for the proper functioning
               of a callback at the flow level and this object provides an easy-to-use interface.
    """
    Error = FlowCallbackError

    def __init__(self, func_name, flow, deps, cbk_data):
        """
        Args:
            func_name: String with the name of the callback to execute.
                       func_name must be a bound method of flow with signature:

                            func_name(self, cbk)

                       where self is the Flow instance and cbk is the callback
            flow: Reference to the |Flow|.
            deps: List of dependencies associated to the callback
                  The callback is executed when all dependencies reach S_OK.
            cbk_data: Dictionary with additional data that will be passed to the callback via self.
                  Stored in `self.data`. An empty dictionary is used if cbk_data evaluates to False.
        """
        self.func_name = func_name
        self.flow = flow
        self.deps = deps
        self.data = cbk_data or {}
        self._disabled = False

    def __str__(self):
        return "%s: %s bound to %s" % (self.__class__.__name__, self.func_name, self.flow)

    def __call__(self):
        """
        Execute the callback and return its result.

        Raises: `self.Error` if the callback cannot be executed or
            if the flow has no attribute named `func_name`.
        """
        if not self.can_execute():
            raise self.Error("You tried to __call_ a callback that cannot be executed!")

        # The method is retrieved by name from the flow.
        # We use this trick because pickle (format <=3) does not support bound methods.
        try:
            func = getattr(self.flow, self.func_name)
        except AttributeError as exc:
            raise self.Error(str(exc))

        return func(self)

    def can_execute(self):
        """True if the callback is enabled and the status of each dependency is S_OK."""
        if self._disabled:
            return False
        return all(dep.status == dep.node.S_OK for dep in self.deps)

    def disable(self):
        """
        Disable the callback. This usually happens when the callback has been executed.
        """
        self._disabled = True

    def enable(self):
        """Enable the callback"""
        self._disabled = False

    def handle_sender(self, sender):
        """
        True if the callback is associated to the sender
        i.e. if the node who sent the signal appears in the
        dependencies of the callback.
        """
        nodes = [dep.node for dep in self.deps]
        return sender in nodes
2994# Factory functions.
def bandstructure_flow(workdir, scf_input, nscf_input, dos_inputs=None, manager=None, flow_class=Flow, allocate=True):
    """
    Build a |Flow| with a single :class:`BandStructureWork` for band structure calculations.

    The returned flow has the attributes `scf_task`, `nscf_task` and `dos_tasks`
    that are aliases for the corresponding attributes of the work.

    Args:
        workdir: Working directory.
        scf_input: Input for the GS SCF run.
        nscf_input: Input for the NSCF run (band structure run).
        dos_inputs: Input(s) for the NSCF run (dos run).
        manager: |TaskManager| object used to submit the jobs. Initialized from manager.yml if manager is None.
        flow_class: Flow subclass
        allocate: True if the flow should be allocated before returning.

    Returns: |Flow| object
    """
    flow = flow_class(workdir, manager=manager)
    bands_work = BandStructureWork(scf_input, nscf_input, dos_inputs=dos_inputs)
    flow.register_work(bands_work)

    # Handy aliases: read all the tasks from the work, then attach them to the flow.
    scf_task, nscf_task, dos_tasks = bands_work.scf_task, bands_work.nscf_task, bands_work.dos_tasks
    flow.scf_task = scf_task
    flow.nscf_task = nscf_task
    flow.dos_tasks = dos_tasks

    if allocate:
        flow.allocate()

    return flow
def g0w0_flow(workdir, scf_input, nscf_input, scr_input, sigma_inputs, manager=None, flow_class=Flow, allocate=True):
    """
    Build a |Flow| with a single :class:`G0W0Work` for one-shot $G_0W_0$ calculations.

    Args:
        workdir: Working directory.
        scf_input: Input for the GS SCF run.
        nscf_input: Input for the NSCF run (band structure run).
        scr_input: Input for the SCR run.
        sigma_inputs: List of inputs for the SIGMA run.
        manager: |TaskManager| object used to submit the jobs. Initialized from manager.yml if manager is None.
        flow_class: Flow class
        allocate: True if the flow should be allocated before returning.

    Returns: |Flow| object
    """
    flow = flow_class(workdir, manager=manager)
    gw_work = G0W0Work(scf_input, nscf_input, scr_input, sigma_inputs)
    flow.register_work(gw_work)

    if allocate:
        flow.allocate()

    return flow
class PhononFlow(Flow):
    """
    This Flow provides a high-level interface to compute phonons with DFPT
    The flow consists of

    1) One workflow for the GS run.

    2) nqpt works for phonon calculations. Each work contains
       nirred tasks where nirred is the number of irreducible phonon perturbations
       for that particular q-point.

    .. note::

        For a much more flexible interface, use the DFPT works defined in works.py
        For instance, EPH calculations are much easier to implement by connecting a single
        work that computes all the q-points with the EPH tasks instead of using PhononFlow.
    """
    @classmethod
    def from_scf_input(cls, workdir, scf_input, ph_ngqpt, with_becs=True, manager=None, allocate=True):
        """
        Create a `PhononFlow` for phonon calculations from an `AbinitInput` defining a ground-state run.

        Args:
            workdir: Working directory of the flow.
            scf_input: |AbinitInput| object with the parameters for the GS-SCF run.
            ph_ngqpt: q-mesh for phonons. Must be a sub-mesh of the k-mesh used for
                electrons. e.g if ngkpt = (8, 8, 8). ph_ngqpt = (4, 4, 4) is a valid choice
                whereas ph_ngqpt = (3, 3, 3) is not!
            with_becs: True if Born effective charges are wanted. In this case a `BecWork`
                is registered for the Gamma point instead of a `PhononWork`.
            manager: |TaskManager| object. Read from `manager.yml` if None.
            allocate: True if the flow should be allocated before returning.

        Return:
            :class:`PhononFlow` object.

        Raises:
            ValueError: if `ph_ngqpt` is not a sub-mesh of the `ngkpt` mesh found in `scf_input`.
        """
        flow = cls(workdir, manager=manager)

        # Register the SCF task. It is the first task of the first work.
        flow.register_scf_task(scf_input)
        scf_task = flow[0][0]

        # Make sure k-mesh and q-mesh are compatible.
        scf_ngkpt, ph_ngqpt = np.array(scf_input["ngkpt"]), np.array(ph_ngqpt)

        if any(scf_ngkpt % ph_ngqpt != 0):
            raise ValueError("ph_ngqpt %s should be a sub-mesh of scf_ngkpt %s" % (ph_ngqpt, scf_ngkpt))

        # Get the q-points in the IBZ from Abinit
        qpoints = scf_input.abiget_ibz(ngkpt=ph_ngqpt, shiftk=(0, 0, 0), kptopt=1).points

        # Create a PhononWork for each q-point. Add DDK and E-field if q == Gamma and with_becs.
        for qpt in qpoints:
            if np.allclose(qpt, 0) and with_becs:
                ph_work = BecWork.from_scf_task(scf_task)
            else:
                ph_work = PhononWork.from_scf_task(scf_task, qpoints=qpt)

            flow.register_work(ph_work)

        if allocate:
            flow.allocate()

        return flow

    def open_final_ddb(self):
        """
        Open the DDB file located in the output directory of the flow.

        Return:
            :class:`DdbFile` object, None if file could not be found or file is not readable.
        """
        ddb_path = self.outdir.has_abiext("DDB")
        if not ddb_path:
            # A completed flow without a DDB file is an anomaly worth logging.
            if self.status == self.S_OK:
                self.history.critical("%s reached S_OK but didn't produce a DDB file in %s" % (self, self.outdir))
            return None

        # Local import, as in the original code.
        from abipy.dfpt.ddb import DdbFile
        try:
            return DdbFile(ddb_path)
        except Exception as exc:
            # Best-effort API: log the problem and signal the failure with None.
            self.history.critical("Exception while reading DDB file at %s:\n%s" % (ddb_path, str(exc)))
            return None

    def finalize(self):
        """
        This method is called when the flow is completed.

        Merge the DDB files found in the output directories of the works into
        a single `out_DDB` file in the output directory of the flow, then call
        the `finalize` method of the super class.

        Return:
            The value returned by `finalize` of the super class.
        """
        # Collect the DDB files found in work.outdir (works without DDB are skipped).
        ddb_files = list(filter(None, [work.outdir.has_abiext("DDB") for work in self]))

        # Final DDB file will be produced in the outdir of the flow.
        out_ddb = self.outdir.path_in("out_DDB")
        desc = "DDB file merged by %s on %s" % (self.__class__.__name__, time.asctime())

        mrgddb = wrappers.Mrgddb(manager=self.manager, verbose=0)
        mrgddb.merge(self.outdir.path, ddb_files, out_ddb=out_ddb, description=desc)
        print("Final DDB file available at %s" % out_ddb)

        # Call the method of the super class.
        return super().finalize()
class NonLinearCoeffFlow(Flow):
    """
    Flow built by `from_scf_input` for second order susceptibility calculations.
    The flow consists of

    1) One workflow for the GS run.

    2) One `DteWork` built from the GS-SCF task.
    """
    @classmethod
    def from_scf_input(cls, workdir, scf_input, manager=None, allocate=True):
        """
        Create a `NonLinearCoeffFlow` for second order susceptibility calculations from
        an `AbinitInput` defining a ground-state run.

        Args:
            workdir: Working directory of the flow.
            scf_input: |AbinitInput| object with the parameters for the GS-SCF run.
            manager: |TaskManager| object. Read from `manager.yml` if None.
            allocate: True if the flow should be allocated before returning.

        Return:
            :class:`NonLinearCoeffFlow` object.
        """
        flow = cls(workdir, manager=manager)

        # Register the SCF task. It is the first task of the first work.
        flow.register_scf_task(scf_input)
        scf_task = flow[0][0]

        # Register the work connected to the SCF task.
        nl_work = DteWork.from_scf_task(scf_task)
        flow.register_work(nl_work)

        if allocate:
            flow.allocate()

        return flow

    def open_final_ddb(self):
        """
        Open the DDB file located in the output directory of the flow.

        Return:
            |DdbFile| object, None if file could not be found or file is not readable.
        """
        ddb_path = self.outdir.has_abiext("DDB")
        if not ddb_path:
            # A completed flow without a DDB file is an anomaly worth logging.
            if self.status == self.S_OK:
                self.history.critical("%s reached S_OK but didn't produce a DDB file in %s" % (self, self.outdir))
            return None

        # Local import, as in the original code.
        from abipy.dfpt.ddb import DdbFile
        try:
            return DdbFile(ddb_path)
        except Exception as exc:
            # Best-effort API: log the problem and signal the failure with None.
            self.history.critical("Exception while reading DDB file at %s:\n%s" % (ddb_path, str(exc)))
            return None

    def finalize(self):
        """
        This method is called when the flow is completed.

        Merge the DDB files found in the output directories of the works into
        a single `out_DDB` file in the output directory of the flow, then call
        the `finalize` method of the super class.

        Return:
            The value returned by `finalize` of the super class.
        """
        # Collect the DDB files found in work.outdir (works without DDB are skipped).
        ddb_files = list(filter(None, [work.outdir.has_abiext("DDB") for work in self]))

        # Final DDB file will be produced in the outdir of the flow.
        out_ddb = self.outdir.path_in("out_DDB")
        desc = "DDB file merged by %s on %s" % (self.__class__.__name__, time.asctime())

        mrgddb = wrappers.Mrgddb(manager=self.manager, verbose=0)
        mrgddb.merge(self.outdir.path, ddb_files, out_ddb=out_ddb, description=desc)

        print("Final DDB file available at %s" % out_ddb)

        # Call the method of the super class.
        return super().finalize()
def phonon_conv_flow(workdir, scf_input, qpoints, params, manager=None, allocate=True):
    """
    Build a |Flow| for convergence studies of phonon calculations.

    For every q-point and for every input generated by ``scf_input.product(*params)``,
    the flow gets one work with the SCF task followed by one `PhononWork`
    connected to that SCF task.

    Args:
        workdir: Working directory of the flow.
        scf_input: |AbinitInput| object defining a GS-SCF calculation.
        qpoints: Reduced coordinates of the q-point(s). The argument is reshaped
            to an array with three columns, one row per q-point.
        params:
            Arguments passed to ``scf_input.product``.
            To perform a convergence study wrt ecut: params=["ecut", [2, 4, 6]]
        manager: |TaskManager| object responsible for the submission of the jobs.
            If manager is None, the object is initialized from the yaml file
            located either in the working directory or in the user configuration dir.
        allocate: True if the flow should be allocated before returning.

    Return: |Flow| object.
    """
    qpt_list = np.reshape(qpoints, (-1, 3))
    conv_flow = Flow(workdir=workdir, manager=manager)

    for qpoint in qpt_list:
        # The inputs are regenerated for each q-point.
        for gs_input in scf_input.product(*params):
            # Work with the SCF task; the task is its first (and only) entry.
            scf_work = conv_flow.register_scf_task(gs_input)
            scf_task = scf_work[0]

            # Phonon work depending on the SCF task registered above.
            ph_work = PhononWork.from_scf_task(scf_task, qpoints=qpoint)
            conv_flow.register_work(ph_work)

    if allocate:
        conv_flow.allocate()

    return conv_flow