1# coding: utf-8
2"""
3A Flow is a container for Works, and works consist of tasks.
4Flows are the final objects that can be dumped directly to a pickle file on disk
5Flows are executed using abirun (abipy).
6"""
7import os
8import sys
9import time
10import collections
11import warnings
12import shutil
13import tempfile
14import numpy as np
15
16from io import StringIO
17from pprint import pprint
18from tabulate import tabulate
19from pydispatch import dispatcher
20from collections import OrderedDict
21from monty.collections import dict2namedtuple
22from monty.string import list_strings, is_string, make_banner
23from monty.operator import operator_from_str
24from monty.io import FileLock
25from monty.pprint import draw_tree
26from monty.termcolor import cprint, colored, cprint_map, get_terminal_size
27from monty.inspect import find_top_pyfile
28from monty.json import MSONable
29from pymatgen.util.serialization import pmg_pickle_load, pmg_pickle_dump, pmg_serialize
30from pymatgen.core.units import Memory
31from pymatgen.util.io_utils import AtomicFile
32from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt
33from abipy.tools.printing import print_dataframe
34from abipy.flowtk import wrappers
35from .nodes import Status, Node, NodeError, NodeResults, Dependency, GarbageCollector, check_spectator
36from .tasks import ScfTask, TaskManager, FixQueueCriticalError
37from .utils import File, Directory, Editor
38from .works import NodeContainer, Work, BandStructureWork, PhononWork, BecWork, G0W0Work, QptdmWork, DteWork
39from .events import EventsParser
40
41__author__ = "Matteo Giantomassi"
42__copyright__ = "Copyright 2013, The Materials Project"
43__version__ = "0.1"
44__maintainer__ = "Matteo Giantomassi"
45
46
47__all__ = [
48    "Flow",
49    "G0W0WithQptdmFlow",
50    "bandstructure_flow",
51    "g0w0_flow",
52]
53
54
55def as_set(obj):
56    """
57    Convert obj into a set, returns None if obj is None.
58
59    >>> assert as_set(None) is None and as_set(1) == set([1]) and as_set(range(1,3)) == set([1, 2])
60    """
61    if obj is None or isinstance(obj, collections.abc.Set):
62        return obj
63
64    if not isinstance(obj, collections.abc.Iterable):
65        return set((obj,))
66    else:
67        return set(obj)
68
69
70class FlowResults(NodeResults):
71
72    JSON_SCHEMA = NodeResults.JSON_SCHEMA.copy()
73    #JSON_SCHEMA["properties"] = {
74    #    "queries": {"type": "string", "required": True},
75    #}
76
77    @classmethod
78    def from_node(cls, flow):
79        """Initialize an instance from a Work instance."""
80        new = super().from_node(flow)
81
82        # Will put all files found in outdir in GridFs
83        d = {os.path.basename(f): f for f in flow.outdir.list_filepaths()}
84
85        # Add the pickle file.
86        d["pickle"] = flow.pickle_file if flow.pickle_protocol != 0 else (flow.pickle_file, "t")
87        new.add_gridfs_files(**d)
88
89        return new
90
91
92class FlowError(NodeError):
93    """Base Exception for :class:`Node` methods"""
94
95
96class Flow(Node, NodeContainer, MSONable):
97    """
98    This object is a container of work. Its main task is managing the
99    possible inter-dependencies among the work and the creation of
100    dynamic workflows that are generated by callbacks registered by the user.
101
102    Attributes:
103
104        creation_date: String with the creation_date
105        pickle_protocol: Protocol for Pickle database (default: -1 i.e. latest protocol)
106
107    Important methods for constructing flows:
108
109        register_work: register (add) a work to the flow
110        resister_task: register a work that contains only this task returns the work
111        allocate: propagate the workdir and manager of the flow to all the registered tasks
112        build:
113        build_and_pickle_dump:
114    """
115    VERSION = "0.1"
116    PICKLE_FNAME = "__AbinitFlow__.pickle"
117
118    Error = FlowError
119
120    Results = FlowResults
121
122    @classmethod
123    def from_inputs(cls, workdir, inputs, manager=None, pickle_protocol=-1, task_class=ScfTask,
124                    work_class=Work, remove=False):
125
126        """
127        Construct a simple flow from a list of inputs. The flow contains a single Work with
128        tasks whose class is given by task_class.
129
130        .. warning::
131
132            Don't use this interface if you have dependencies among the tasks.
133
134        Args:
135            workdir: String specifying the directory where the works will be produced.
136            inputs: List of inputs.
137            manager: |TaskManager| object responsible for the submission of the jobs.
138                If manager is None, the object is initialized from the yaml file
139                located either in the working directory or in the user configuration dir.
140            pickle_protocol: Pickle protocol version used for saving the status of the object.
141                -1 denotes the latest version supported by the python interpreter.
142            task_class: The class of the |Task|.
143            work_class: The class of the |Work|.
144            remove: attempt to remove working directory `workdir` if directory already exists.
145        """
146        if not isinstance(inputs, (list, tuple)): inputs = [inputs]
147
148        flow = cls(workdir, manager=manager, pickle_protocol=pickle_protocol, remove=remove)
149        work = work_class()
150        for inp in inputs:
151            work.register(inp, task_class=task_class)
152        flow.register_work(work)
153
154        return flow.allocate()
155
156    @classmethod
157    def as_flow(cls, obj):
158        """Convert obj into a Flow. Accepts filepath, dict, or Flow object."""
159        if isinstance(obj, cls): return obj
160        if is_string(obj):
161            return cls.pickle_load(obj)
162        elif isinstance(obj, collections.abc.Mapping):
163            return cls.from_dict(obj)
164        else:
165            raise TypeError("Don't know how to convert type %s into a Flow" % type(obj))
166
167    def __init__(self, workdir, manager=None, pickle_protocol=-1, remove=False):
168        """
169        Args:
170            workdir: String specifying the directory where the works will be produced.
171                     if workdir is None, the initialization of the working directory
172                     is performed by flow.allocate(workdir).
173            manager: |TaskManager| object responsible for the submission of the jobs.
174                     If manager is None, the object is initialized from the yaml file
175                     located either in the working directory or in the user configuration dir.
176            pickle_protocol: Pickle protocol version used for saving the status of the object.
177                          -1 denotes the latest version supported by the python interpreter.
178            remove: attempt to remove working directory `workdir` if directory already exists.
179        """
180        super().__init__()
181
182        if workdir is not None:
183            if remove and os.path.exists(workdir): shutil.rmtree(workdir)
184            self.set_workdir(workdir)
185
186        self.creation_date = time.asctime()
187
188        if manager is None: manager = TaskManager.from_user_config()
189        self.manager = manager.deepcopy()
190
191        # List of works.
192        self._works = []
193
194        self._waited = 0
195
196        # List of callbacks that must be executed when the dependencies reach S_OK
197        self._callbacks = []
198
199        # Install default list of handlers at the flow level.
200        # Users can override the default list by calling flow.install_event_handlers in the script.
201        # Example:
202        #
203        #    # flow level (common case)
204        #    flow.install_event_handlers(handlers=my_handlers)
205        #
206        #    # task level (advanced mode)
207        #    flow[0][0].install_event_handlers(handlers=my_handlers)
208        #
209        self.install_event_handlers()
210
211        self.pickle_protocol = int(pickle_protocol)
212
213        # ID used to access mongodb
214        self._mongo_id = None
215
216        # Save the location of the script used to generate the flow.
217        # This trick won't work if we are running with nosetests, py.test etc
218        pyfile = find_top_pyfile()
219        if "python" in pyfile or "ipython" in pyfile: pyfile = "<" + pyfile + ">"
220        self.set_pyfile(pyfile)
221
222        # TODO
223        # Signal slots: a dictionary with the list
224        # of callbacks indexed by node_id and SIGNAL_TYPE.
225        # When the node changes its status, it broadcast a signal.
226        # The flow is listening to all the nodes of the calculation
227        # [node_id][SIGNAL] = list_of_signal_handlers
228        #self._sig_slots =  slots = {}
229        #for work in self:
230        #    slots[work] = {s: [] for s in work.S_ALL}
231
232        #for task in self.iflat_tasks():
233        #    slots[task] = {s: [] for s in work.S_ALL}
234
235        self.on_all_ok_num_calls = 0
236
237    @pmg_serialize
238    def as_dict(self, **kwargs):
239        """
240        JSON serialization, note that we only need to save
241        a string with the working directory since the object will be
242        reconstructed from the pickle file located in workdir
243        """
244        return {"workdir": self.workdir}
245
246    # This is needed for fireworks.
247    to_dict = as_dict
248
249    @classmethod
250    def from_dict(cls, d, **kwargs):
251        """Reconstruct the flow from the pickle file."""
252        return cls.pickle_load(d["workdir"], **kwargs)
253
254    @classmethod
255    def temporary_flow(cls, manager=None):
256        """Return a Flow in a temporary directory. Useful for unit tests."""
257        return cls(workdir=tempfile.mkdtemp(), manager=manager)
258
259    def set_workdir(self, workdir, chroot=False):
260        """
261        Set the working directory. Cannot be set more than once unless chroot is True
262        """
263        if not chroot and hasattr(self, "workdir") and self.workdir != workdir:
264            raise ValueError("self.workdir != workdir: %s, %s" % (self.workdir,  workdir))
265
266        # Directories with (input|output|temporary) data.
267        self.workdir = os.path.abspath(workdir)
268        self.indir = Directory(os.path.join(self.workdir, "indata"))
269        self.outdir = Directory(os.path.join(self.workdir, "outdata"))
270        self.tmpdir = Directory(os.path.join(self.workdir, "tmpdata"))
271        self.wdir = Directory(self.workdir)
272
273    def reload(self):
274        """
275        Reload the flow from the pickle file. Used when we are monitoring the flow
276        executed by the scheduler. In this case, indeed, the flow might have been changed
277        by the scheduler and we have to reload the new flow in memory.
278        """
279        new = self.__class__.pickle_load(self.workdir)
280        self = new
281
282    @classmethod
283    def pickle_load(cls, filepath, spectator_mode=True, remove_lock=False):
284        """
285        Loads the object from a pickle file and performs initial setup.
286
287        Args:
288            filepath: Filename or directory name. It filepath is a directory, we
289                scan the directory tree starting from filepath and we
290                read the first pickle database. Raise RuntimeError if multiple
291                databases are found.
292            spectator_mode: If True, the nodes of the flow are not connected by signals.
293                This option is usually used when we want to read a flow
294                in read-only mode and we want to avoid callbacks that can change the flow.
295            remove_lock:
296                True to remove the file lock if any (use it carefully).
297        """
298        if os.path.isdir(filepath):
299            # Walk through each directory inside path and find the pickle database.
300            for dirpath, dirnames, filenames in os.walk(filepath):
301                fnames = [f for f in filenames if f == cls.PICKLE_FNAME]
302                if fnames:
303                    if len(fnames) == 1:
304                        filepath = os.path.join(dirpath, fnames[0])
305                        break  # Exit os.walk
306                    else:
307                        err_msg = "Found multiple databases:\n %s" % str(fnames)
308                        raise RuntimeError(err_msg)
309            else:
310                err_msg = "Cannot find %s inside directory %s" % (cls.PICKLE_FNAME, filepath)
311                raise ValueError(err_msg)
312
313        if remove_lock and os.path.exists(filepath + ".lock"):
314            try:
315                os.remove(filepath + ".lock")
316            except Exception:
317                pass
318
319        with FileLock(filepath):
320            with open(filepath, "rb") as fh:
321                flow = pmg_pickle_load(fh)
322
323        # Check if versions match.
324        if flow.VERSION != cls.VERSION:
325            msg = ("File flow version %s != latest version %s\n."
326                   "Regenerate the flow to solve the problem " % (flow.VERSION, cls.VERSION))
327            warnings.warn(msg)
328
329        flow.set_spectator_mode(spectator_mode)
330
331        # Recompute the status of each task since tasks that
332        # have been submitted previously might be completed.
333        flow.check_status()
334        return flow
335
336    # Handy alias
337    from_file = pickle_load
338
339    @classmethod
340    def pickle_loads(cls, s):
341        """Reconstruct the flow from a string."""
342        strio = StringIO()
343        strio.write(s)
344        strio.seek(0)
345        flow = pmg_pickle_load(strio)
346        return flow
347
348    def get_panel(self):
349        """Build panel with widgets to interact with the |Flow| either in a notebook or in panel app."""
350        from abipy.panels.flows import FlowPanel
351        return FlowPanel(self).get_panel()
352
353    def __len__(self):
354        return len(self.works)
355
356    def __iter__(self):
357        return self.works.__iter__()
358
359    def __getitem__(self, slice):
360        return self.works[slice]
361
362    def set_pyfile(self, pyfile):
363        """
364        Set the path of the python script used to generate the flow.
365
366        .. Example:
367
368            flow.set_pyfile(__file__)
369        """
370        # TODO: Could use a frame hack to get the caller outside abinit
371        # so that pyfile is automatically set when we __init__ it!
372        self._pyfile = os.path.abspath(pyfile)
373
374    @property
375    def pyfile(self):
376        """
377        Absolute path of the python script used to generate the flow. Set by `set_pyfile`
378        """
379        try:
380            return self._pyfile
381        except AttributeError:
382            return None
383
384    @property
385    def pid_file(self):
386        """The path of the pid file created by PyFlowScheduler."""
387        return os.path.join(self.workdir, "_PyFlowScheduler.pid")
388
389    @property
390    def has_scheduler(self):
391        """True if there's a scheduler running the flow."""
392        return os.path.exists(self.pid_file)
393
394    def check_pid_file(self):
395        """
396        This function checks if we are already running the |Flow| with a :class:`PyFlowScheduler`.
397        Raises: Flow.Error if the pid file of the scheduler exists.
398        """
399        if not os.path.exists(self.pid_file):
400            return 0
401
402        self.show_status()
403        raise self.Error("""\n\
404            pid_file
405            %s
406            already exists. There are two possibilities:
407
408               1) There's an another instance of PyFlowScheduler running
409               2) The previous scheduler didn't exit in a clean way
410
411            To solve case 1:
412               Kill the previous scheduler (use 'kill pid' where pid is the number reported in the file)
413               Then you can restart the new scheduler.
414
415            To solve case 2:
416               Remove the pid_file and restart the scheduler.
417
418            Exiting""" % self.pid_file)
419
420    @property
421    def pickle_file(self):
422        """The path of the pickle file."""
423        return os.path.join(self.workdir, self.PICKLE_FNAME)
424
425    @property
426    def mongo_id(self):
427        return self._mongo_id
428
429    @mongo_id.setter
430    def mongo_id(self, value):
431        if self.mongo_id is not None:
432            raise RuntimeError("Cannot change mongo_id %s" % self.mongo_id)
433        self._mongo_id = value
434
435    #def mongodb_upload(self, **kwargs):
436    #    from abiflows.core.scheduler import FlowUploader
437    #    FlowUploader().upload(self, **kwargs)
438
439    def validate_json_schema(self):
440        """Validate the JSON schema. Return list of errors."""
441        errors = []
442
443        for work in self:
444            for task in work:
445                if not task.get_results().validate_json_schema():
446                    errors.append(task)
447            if not work.get_results().validate_json_schema():
448                errors.append(work)
449        if not self.get_results().validate_json_schema():
450            errors.append(self)
451
452        return errors
453
454    def get_mongo_info(self):
455        """
456        Return a JSON dictionary with information on the flow.
457        Mainly used for constructing the info section in `FlowEntry`.
458        The default implementation is empty. Subclasses must implement it
459        """
460        return {}
461
462    def mongo_assimilate(self):
463        """
464        This function is called by client code when the flow is completed
465        Return a JSON dictionary with the most important results produced
466        by the flow. The default implementation is empty. Subclasses must implement it
467        """
468        return {}
469
470    @property
471    def works(self):
472        """List of |Work| objects contained in self.."""
473        return self._works
474
475    @property
476    def all_ok(self):
477        """True if all the tasks in works have reached `S_OK`."""
478        all_ok = all(work.all_ok for work in self)
479        if all_ok:
480            all_ok = self.on_all_ok()
481        return all_ok
482
483    @property
484    def num_tasks(self):
485        """Total number of tasks"""
486        return len(list(self.iflat_tasks()))
487
488    @property
489    def errored_tasks(self):
490        """List of errored tasks."""
491        etasks = []
492        for status in [self.S_ERROR, self.S_QCRITICAL, self.S_ABICRITICAL]:
493            etasks.extend(list(self.iflat_tasks(status=status)))
494
495        return set(etasks)
496
497    @property
498    def num_errored_tasks(self):
499        """The number of tasks whose status is `S_ERROR`."""
500        return len(self.errored_tasks)
501
502    @property
503    def unconverged_tasks(self):
504        """List of unconverged tasks."""
505        return list(self.iflat_tasks(status=self.S_UNCONVERGED))
506
507    @property
508    def num_unconverged_tasks(self):
509        """The number of tasks whose status is `S_UNCONVERGED`."""
510        return len(self.unconverged_tasks)
511
512    @property
513    def status_counter(self):
514        """
515        Returns a :class:`Counter` object that counts the number of tasks with
516        given status (use the string representation of the status as key).
517        """
518        # Count the number of tasks with given status in each work.
519        counter = self[0].status_counter
520        for work in self[1:]:
521            counter += work.status_counter
522
523        return counter
524
525    @property
526    def ncores_reserved(self):
527        """
528        Returns the number of cores reserved in this moment.
529        A core is reserved if the task is not running but
530        we have submitted the task to the queue manager.
531        """
532        return sum(work.ncores_reserved for work in self)
533
534    @property
535    def ncores_allocated(self):
536        """
537        Returns the number of cores allocated in this moment.
538        A core is allocated if it's running a task or if we have
539        submitted a task to the queue manager but the job is still pending.
540        """
541        return sum(work.ncores_allocated for work in self)
542
543    @property
544    def ncores_used(self):
545        """
546        Returns the number of cores used in this moment.
547        A core is used if there's a job that is running on it.
548        """
549        return sum(work.ncores_used for work in self)
550
551    @property
552    def has_chrooted(self):
553        """
554        Returns a string that evaluates to True if we have changed
555        the workdir for visualization purposes e.g. we are using sshfs.
556        to mount the remote directory where the `Flow` is located.
557        The string gives the previous workdir of the flow.
558        """
559        try:
560            return self._chrooted_from
561        except AttributeError:
562            return ""
563
564    def chroot(self, new_workdir):
565        """
566        Change the workir of the |Flow|. Mainly used for
567        allowing the user to open the GUI on the local host
568        and access the flow from remote via sshfs.
569
570        .. note::
571            Calling this method will make the flow go in read-only mode.
572        """
573        self._chrooted_from = self.workdir
574        self.set_workdir(new_workdir, chroot=True)
575
576        for i, work in enumerate(self):
577            new_wdir = os.path.join(self.workdir, "w" + str(i))
578            work.chroot(new_wdir)
579
580    def groupby_status(self):
581        """
582        Returns a ordered dictionary mapping the task status to
583        the list of named tuples (task, work_index, task_index).
584        """
585        Entry = collections.namedtuple("Entry", "task wi ti")
586        d = collections.defaultdict(list)
587
588        for task, wi, ti in self.iflat_tasks_wti():
589            d[task.status].append(Entry(task, wi, ti))
590
591        # Sort keys according to their status.
592        return OrderedDict([(k, d[k]) for k in sorted(list(d.keys()))])
593
594    def groupby_task_class(self):
595        """
596        Returns a dictionary mapping the task class to the list of tasks in the flow
597        """
598        # Find all Task classes
599        class2tasks = OrderedDict()
600        for task in self.iflat_tasks():
601            cls = task.__class__
602            if cls not in class2tasks: class2tasks[cls] = []
603            class2tasks[cls].append(task)
604
605        return class2tasks
606
607    def iflat_nodes(self, status=None, op="==", nids=None):
608        """
609        Generators that produces a flat sequence of nodes.
610        if status is not None, only the tasks with the specified status are selected.
611        nids is an optional list of node identifiers used to filter the nodes.
612        """
613        nids = as_set(nids)
614
615        if status is None:
616            if not (nids and self.node_id not in nids):
617                yield self
618
619            for work in self:
620                if nids and work.node_id not in nids: continue
621                yield work
622                for task in work:
623                    if nids and task.node_id not in nids: continue
624                    yield task
625        else:
626            # Get the operator from the string.
627            op = operator_from_str(op)
628
629            # Accept Task.S_FLAG or string.
630            status = Status.as_status(status)
631
632            if not (nids and self.node_id not in nids):
633                if op(self.status, status): yield self
634
635            for wi, work in enumerate(self):
636                if nids and work.node_id not in nids: continue
637                if op(work.status, status): yield work
638
639                for ti, task in enumerate(work):
640                    if nids and task.node_id not in nids: continue
641                    if op(task.status, status): yield task
642
643    def node_from_nid(self, nid):
644        """Return the node in the `Flow` with the given `nid` identifier"""
645        for node in self.iflat_nodes():
646            if node.node_id == nid: return node
647        raise ValueError("Cannot find node with node id: %s" % nid)
648
649    def iflat_tasks_wti(self, status=None, op="==", nids=None):
650        """
651        Generator to iterate over all the tasks of the `Flow`.
652        Yields:
653
654            (task, work_index, task_index)
655
656        If status is not None, only the tasks whose status satisfies
657        the condition (task.status op status) are selected
658        status can be either one of the flags defined in the |Task| class
659        (e.g Task.S_OK) or a string e.g "S_OK"
660        nids is an optional list of node identifiers used to filter the tasks.
661        """
662        return self._iflat_tasks_wti(status=status, op=op, nids=nids, with_wti=True)
663
664    def iflat_tasks(self, status=None, op="==", nids=None):
665        """
666        Generator to iterate over all the tasks of the |Flow|.
667
668        If status is not None, only the tasks whose status satisfies
669        the condition (task.status op status) are selected
670        status can be either one of the flags defined in the |Task| class
671        (e.g Task.S_OK) or a string e.g "S_OK"
672        nids is an optional list of node identifiers used to filter the tasks.
673        """
674        return self._iflat_tasks_wti(status=status, op=op, nids=nids, with_wti=False)
675
676    def _iflat_tasks_wti(self, status=None, op="==", nids=None, with_wti=True):
677        """
678        Generators that produces a flat sequence of task.
679        if status is not None, only the tasks with the specified status are selected.
680        nids is an optional list of node identifiers used to filter the tasks.
681
682        Returns:
683            (task, work_index, task_index) if with_wti is True else task
684        """
685        nids = as_set(nids)
686
687        if status is None:
688            for wi, work in enumerate(self):
689                for ti, task in enumerate(work):
690                    if nids and task.node_id not in nids: continue
691                    if with_wti:
692                        yield task, wi, ti
693                    else:
694                        yield task
695
696        else:
697            # Get the operator from the string.
698            op = operator_from_str(op)
699
700            # Accept Task.S_FLAG or string.
701            status = Status.as_status(status)
702
703            for wi, work in enumerate(self):
704                for ti, task in enumerate(work):
705                    if nids and task.node_id not in nids: continue
706                    if op(task.status, status):
707                        if with_wti:
708                            yield task, wi, ti
709                        else:
710                            yield task
711
712    def abivalidate_inputs(self):
713        """
714        Run ABINIT in dry mode to validate all the inputs of the flow.
715
716        Return:
717            (isok, tuples)
718
719            isok is True if all inputs are ok.
720            tuples is List of `namedtuple` objects, one for each task in the flow.
721            Each namedtuple has the following attributes:
722
723                retcode: Return code. 0 if OK.
724                log_file:  log file of the Abinit run, use log_file.read() to access its content.
725                stderr_file: stderr file of the Abinit run. use stderr_file.read() to access its content.
726
727        Raises:
728            `RuntimeError` if executable is not in $PATH.
729        """
730        if not self.allocated:
731            self.allocate()
732
733        isok, tuples = True, []
734        for task in self.iflat_tasks():
735            t = task.input.abivalidate()
736            if t.retcode != 0: isok = False
737            tuples.append(t)
738
739        return isok, tuples
740
741    def check_dependencies(self):
742        """Test the dependencies of the nodes for possible deadlocks."""
743        deadlocks = []
744
745        for task in self.iflat_tasks():
746            for dep in task.deps:
747                if dep.node.depends_on(task):
748                    deadlocks.append((task, dep.node))
749
750        if deadlocks:
751            lines = ["Detect wrong list of dependecies that will lead to a deadlock:"]
752            lines.extend(["%s <--> %s" % nodes for nodes in deadlocks])
753            raise RuntimeError("\n".join(lines))
754
755    def find_deadlocks(self):
756        """
757        This function detects deadlocks
758
759        Return:
760            named tuple with the tasks grouped in: deadlocks, runnables, running
761        """
762        # Find jobs that can be submitted and and the jobs that are already in the queue.
763        runnables = []
764        for work in self:
765            runnables.extend(work.fetch_alltasks_to_run())
766        runnables.extend(list(self.iflat_tasks(status=self.S_SUB)))
767
768        # Running jobs.
769        running = list(self.iflat_tasks(status=self.S_RUN))
770
771        # Find deadlocks.
772        err_tasks = self.errored_tasks
773        deadlocked = []
774        if err_tasks:
775            for task in self.iflat_tasks():
776                if any(task.depends_on(err_task) for err_task in err_tasks):
777                    deadlocked.append(task)
778
779        return dict2namedtuple(deadlocked=deadlocked, runnables=runnables, running=running)
780
781    def check_status(self, **kwargs):
782        """
783        Check the status of the works in self.
784
785        Args:
786            show: True to show the status of the flow.
787            kwargs: keyword arguments passed to show_status
788        """
789        for work in self:
790            work.check_status()
791
792        if kwargs.pop("show", False):
793            self.show_status(**kwargs)
794
795    @property
796    def status(self):
797        """The status of the |Flow| i.e. the minimum of the status of its tasks and its works"""
798        return min(work.get_all_status(only_min=True) for work in self)
799
800    #def restart_unconverged_tasks(self, max_nlauch, excs):
801    #    nlaunch = 0
802    #    for task in self.unconverged_tasks:
803    #        try:
804    #            self.history.info("Flow will try restart task %s" % task)
805    #            fired = task.restart()
806    #            if fired:
807    #                nlaunch += 1
808    #                max_nlaunch -= 1
809
810    #                if max_nlaunch == 0:
811    #                    self.history.info("Restart: too many jobs in the queue, returning")
812    #                    self.pickle_dump()
813    #                    return nlaunch, max_nlaunch
814
815    #        except task.RestartError:
816    #            excs.append(straceback())
817
818    #    return nlaunch, max_nlaunch
819
820    def fix_abicritical(self):
821        """
822        This function tries to fix critical events originating from ABINIT.
823        Returns the number of tasks that have been fixed.
824        """
825        count = 0
826        for task in self.iflat_tasks(status=self.S_ABICRITICAL):
827            count += task.fix_abicritical()
828
829        return count
830
831    def fix_queue_critical(self):
832        """
833        This function tries to fix critical events originating from the queue submission system.
834
835        Returns the number of tasks that have been fixed.
836        """
837        count = 0
838        for task in self.iflat_tasks(status=self.S_QCRITICAL):
839            self.history.info("Will try to fix task %s" % str(task))
840            try:
841                print(task.fix_queue_critical())
842                count += 1
843            except FixQueueCriticalError:
844                self.history.info("Not able to fix task %s" % task)
845
846        return count
847
848    def show_info(self, **kwargs):
849        """
850        Print info on the flow i.e. total number of tasks, works, tasks grouped by class.
851
852        Example:
853
854            Task Class      Number
855            ------------  --------
856            ScfTask              1
857            NscfTask             1
858            ScrTask              2
859            SigmaTask            6
860        """
861        stream = kwargs.pop("stream", sys.stdout)
862
863        lines = [str(self)]
864        app = lines.append
865
866        app("Number of works: %d, total number of tasks: %s" % (len(self), self.num_tasks))
867        app("Number of tasks with a given class:\n")
868
869        # Build Table
870        data = [[cls.__name__, len(tasks)]
871                for cls, tasks in self.groupby_task_class().items()]
872        app(str(tabulate(data, headers=["Task Class", "Number"])))
873
874        stream.write("\n".join(lines))
875
876    def compare_abivars(self, varnames, nids=None, wslice=None, printout=False, with_colors=False):
877        """
878        Print the input of the tasks to the given stream.
879
880        Args:
881            varnames:
882                List of Abinit variables. If not None, only the variable in varnames
883                are selected and printed.
884            nids: List of node identifiers. By defaults all nodes are shown
885            wslice: Slice object used to select works.
886            printout: True to print dataframe.
887            with_colors: True if task status should be colored.
888        """
889        varnames = [s.strip() for s in list_strings(varnames)]
890        index, rows = [], []
891        for task in self.select_tasks(nids=nids, wslice=wslice):
892            index.append(task.pos_str)
893            dstruct = task.input.structure.as_dict(fmt="abivars")
894
895            od = OrderedDict()
896            for vname in varnames:
897                value = task.input.get(vname, None)
898                if value is None: # maybe in structure?
899                    value = dstruct.get(vname, None)
900                od[vname] = value
901
902            od["task_class"] = task.__class__.__name__
903            od["status"] = task.status.colored if with_colors else str(task.status)
904            rows.append(od)
905
906        import pandas as pd
907        df = pd.DataFrame(rows, index=index)
908        if printout:
909            print_dataframe(df, title="Input variables:")
910
911        return df
912
913    def get_dims_dataframe(self, nids=None, printout=False, with_colors=False):
914        """
915        Analyze output files produced by the tasks. Print pandas DataFrame with dimensions.
916
917        Args:
918            nids: List of node identifiers. By defaults all nodes are shown
919            printout: True to print dataframe.
920            with_colors: True if task status should be colored.
921        """
922        abo_paths, index, status, abo_relpaths, task_classes, task_nids = [], [], [], [], [], []
923
924        for task in self.iflat_tasks(nids=nids):
925            if task.status not in (self.S_OK, self.S_RUN): continue
926            if not task.is_abinit_task: continue
927
928            abo_paths.append(task.output_file.path)
929            index.append(task.pos_str)
930            status.append(task.status.colored if with_colors else str(task.status))
931            abo_relpaths.append(os.path.relpath(task.output_file.relpath))
932            task_classes.append(task.__class__.__name__)
933            task_nids.append(task.node_id)
934
935        if not abo_paths: return
936
937        # Get dimensions from output files as well as walltime/cputime
938        from abipy.abio.outputs import AboRobot
939        robot = AboRobot.from_files(abo_paths)
940        df = robot.get_dims_dataframe(with_time=True, index=index)
941
942        # Add columns to the dataframe.
943        status = [str(s) for s in status]
944        df["task_class"] = task_classes
945        df["relpath"] = abo_relpaths
946        df["node_id"] = task_nids
947        df["status"] = status
948
949        if printout:
950            print_dataframe(df, title="Table with Abinit dimensions:\n")
951
952        return df
953
954    def compare_structures(self, nids=None, with_spglib=False, what="io", verbose=0,
955                           precision=3, printout=False, with_colors=False):
956        """
957        Analyze structures of the tasks (input and output structures if it's a relaxation task.
958        Print pandas DataFrame
959
960        Args:
961            nids: List of node identifiers. By defaults all nodes are shown
962            with_spglib: If True, spglib is invoked to get the spacegroup symbol and number
963            what (str): "i" for input structures, "o" for output structures.
964            precision: Floating point output precision (number of significant digits).
965                This is only a suggestion
966            printout: True to print dataframe.
967            with_colors: True if task status should be colored.
968        """
969        from abipy.core.structure import dataframes_from_structures
970        structures, index, status, max_forces, pressures, task_classes = [], [], [], [], [], []
971
972        def push_data(post, task, structure, cart_forces, pressure):
973            """Helper function to fill lists"""
974            index.append(task.pos_str + post)
975            structures.append(structure)
976            status.append(task.status.colored if with_colors else str(task.status))
977            if cart_forces is not None:
978                fmods = np.sqrt([np.dot(f, f) for f in cart_forces])
979                max_forces.append(fmods.max())
980            else:
981                max_forces.append(None)
982            pressures.append(pressure)
983            task_classes.append(task.__class__.__name__)
984
985        for task in self.iflat_tasks(nids=nids):
986            if "i" in what:
987                push_data("_in", task, task.input.structure, cart_forces=None, pressure=None)
988
989            if "o" not in what:
990                continue
991
992            # Add final structure, pressure and max force if relaxation task or GS task
993            if task.status in (task.S_RUN, task.S_OK):
994                if hasattr(task, "open_hist"):
995                    # Structural relaxations produce HIST.nc and we can get
996                    # the final structure or the structure of the last relaxation step.
997                    try:
998                        with task.open_hist() as hist:
999                            final_structure = hist.final_structure
1000                            stress_cart_tensors, pressures_hist = hist.reader.read_cart_stress_tensors()
1001                            forces = hist.reader.read_cart_forces(unit="eV ang^-1")[-1]
1002                            push_data("_out", task, final_structure, forces, pressures_hist[-1])
1003                    except Exception as exc:
1004                        cprint("Exception while opening HIST.nc file of task: %s\n%s" % (task, str(exc)), "red")
1005
1006                elif hasattr(task, "open_gsr") and task.status == task.S_OK and task.input.get("iscf", 7) >= 0:
1007                    with task.open_gsr() as gsr:
1008                        forces = gsr.reader.read_cart_forces(unit="eV ang^-1")
1009                        push_data("_out", task, gsr.structure, forces, gsr.pressure)
1010
1011        dfs = dataframes_from_structures(structures, index=index, with_spglib=with_spglib, cart_coords=False)
1012
1013        if any(f is not None for f in max_forces):
1014            # Add pressure and forces to the dataframe
1015            dfs.lattice["P [GPa]"] = pressures
1016            dfs.lattice["Max|F| eV/ang"] = max_forces
1017
1018        # Add columns to the dataframe.
1019        status = [str(s) for s in status]
1020        dfs.lattice["task_class"] = task_classes
1021        dfs.lattice["status"] = dfs.coords["status"] = status
1022
1023        if printout:
1024            print_dataframe(dfs.lattice, title="Lattice parameters:", precision=precision)
1025            if verbose:
1026                print_dataframe(dfs.coords, title="Atomic positions (columns give the site index):")
1027            else:
1028                print("Use `--verbose` to print atoms.")
1029
1030        return dfs
1031
1032    def compare_ebands(self, nids=None, with_path=True, with_ibz=True, with_spglib=False, verbose=0,
1033                       precision=3, printout=False, with_colors=False):
1034        """
1035        Analyze electron bands produced by the tasks.
1036        Return pandas DataFrame and |ElectronBandsPlotter|.
1037
1038        Args:
1039            nids: List of node identifiers. By default, all nodes are shown
1040            with_path: Select files with ebands along k-path.
1041            with_ibz: Select files with ebands in the IBZ.
1042            with_spglib: If True, spglib is invoked to get the spacegroup symbol and number
1043            precision: Floating point output precision (number of significant digits).
1044                This is only a suggestion
1045            printout: True to print dataframe.
1046            with_colors: True if task status should be colored.
1047
1048        Return: (df, ebands_plotter)
1049        """
1050        ebands_list, index, status, ncfiles, task_classes, task_nids = [], [], [], [], [], []
1051
1052        # Cannot use robots because ElectronBands can be found in different filetypes
1053        for task in self.iflat_tasks(nids=nids, status=self.S_OK):
1054            # Read ebands from GSR or SIGRES files.
1055            for ext in ("gsr", "sigres"):
1056                task_open_ncfile = getattr(task, "open_%s" % ext, None)
1057                if task_open_ncfile is not None: break
1058            else:
1059                continue
1060
1061            try:
1062                with task_open_ncfile() as ncfile:
1063                    if not with_path and ncfile.ebands.kpoints.is_path: continue
1064                    if not with_ibz and ncfile.ebands.kpoints.is_ibz: continue
1065                    ebands_list.append(ncfile.ebands)
1066                    index.append(task.pos_str)
1067                    status.append(task.status.colored if with_colors else str(task.status))
1068                    ncfiles.append(os.path.relpath(ncfile.filepath))
1069                    task_classes.append(task.__class__.__name__)
1070                    task_nids.append(task.node_id)
1071            except Exception as exc:
1072                cprint("Exception while opening nc file of task: %s\n%s" % (task, str(exc)), "red")
1073
1074        if not ebands_list: return (None, None)
1075
1076        from abipy.electrons.ebands import dataframe_from_ebands
1077        df = dataframe_from_ebands(ebands_list, index=index, with_spglib=with_spglib)
1078        ncfiles = [os.path.relpath(p, self.workdir) for p in ncfiles]
1079
1080        # Add columns to the dataframe.
1081        status = [str(s) for s in status]
1082        df["task_class"] = task_classes
1083        df["ncfile"] = ncfiles
1084        df["node_id"] = task_nids
1085        df["status"] = status
1086
1087        if printout:
1088            from abipy.tools.printing import print_dataframe
1089            print_dataframe(df, title="KS electronic bands:", precision=precision)
1090
1091        from abipy.electrons.ebands import ElectronBandsPlotter
1092        ebands_plotter = ElectronBandsPlotter(key_ebands=zip(ncfiles, ebands_list))
1093
1094        return df, ebands_plotter
1095
1096    def compare_hist(self, nids=None, with_spglib=False, verbose=0,
1097                     precision=3, printout=False, with_colors=False):
1098        """
1099        Analyze HIST nc files produced by the tasks. Print pandas DataFrame with final results.
1100        Return: (df, hist_plotter)
1101
1102        Args:
1103            nids: List of node identifiers. By defaults all nodes are shown
1104            with_spglib: If True, spglib is invoked to get the spacegroup symbol and number
1105            precision: Floating point output precision (number of significant digits).
1106                This is only a suggestion
1107            printout: True to print dataframe.
1108            with_colors: True if task status should be colored.
1109        """
1110        hist_paths, index, status, ncfiles, task_classes, task_nids = [], [], [], [], [], []
1111
1112        for task in self.iflat_tasks(nids=nids):
1113            if task.status not in (self.S_OK, self.S_RUN): continue
1114            hist_path = task.outdir.has_abiext("HIST")
1115            if not hist_path: continue
1116
1117            hist_paths.append(hist_path)
1118            index.append(task.pos_str)
1119            status.append(task.status.colored if with_colors else str(task.status))
1120            ncfiles.append(os.path.relpath(hist_path))
1121            task_classes.append(task.__class__.__name__)
1122            task_nids.append(task.node_id)
1123
1124        if not hist_paths: return (None, None)
1125        from abipy.dynamics.hist import HistRobot
1126        robot = HistRobot.from_files(hist_paths, labels=hist_paths)
1127        df = robot.get_dataframe(index=index, with_spglib=with_spglib)
1128        ncfiles = [os.path.relpath(p, self.workdir) for p in ncfiles]
1129
1130        # Add columns to the dataframe.
1131        status = [str(s) for s in status]
1132        df["task_class"] = task_classes
1133        df["ncfile"] = ncfiles
1134        df["node_id"] = task_nids
1135        df["status"] = status
1136
1137        if printout:
1138            title = "Table with final structures, pressures in GPa and force stats in eV/Ang:\n"
1139            from abipy.tools.printing import print_dataframe
1140            print_dataframe(df, title=title, precision=precision)
1141
1142        return df, robot
1143
1144    def show_summary(self, **kwargs):
1145        """
1146        Print a short summary with the status of the flow and a counter task_status --> number_of_tasks
1147
1148        Args:
1149            stream: File-like object, Default: sys.stdout
1150
1151        Example:
1152
1153            Status       Count
1154            ---------  -------
1155            Completed       10
1156
1157            <Flow, node_id=27163, workdir=flow_gwconv_ecuteps>, num_tasks=10, all_ok=True
1158        """
1159        stream = kwargs.pop("stream", sys.stdout)
1160        stream.write("\n")
1161        table = list(self.status_counter.items())
1162        s = tabulate(table, headers=["Status", "Count"])
1163        stream.write(s + "\n")
1164        stream.write("\n")
1165        stream.write("%s, num_tasks=%s, all_ok=%s\n" % (str(self), self.num_tasks, self.all_ok))
1166        stream.write("\n")
1167
1168    def show_status(self, **kwargs):
1169        """
1170        Report the status of the works and the status of the different tasks on the specified stream.
1171
1172        Args:
1173            stream: File-like object, Default: sys.stdout
1174            nids:  List of node identifiers. By defaults all nodes are shown
1175            wslice: Slice object used to select works.
1176            verbose: Verbosity level (default 0). > 0 to show only the works that are not finalized.
1177        """
1178        stream = kwargs.pop("stream", sys.stdout)
1179        nids = as_set(kwargs.pop("nids", None))
1180        wslice = kwargs.pop("wslice", None)
1181        verbose = kwargs.pop("verbose", 0)
1182        wlist = None
1183        if wslice is not None:
1184            # Convert range to list of work indices.
1185            wlist = list(range(wslice.start, wslice.step, wslice.stop))
1186
1187        #has_colours = stream_has_colours(stream)
1188        has_colours = True
1189        red = "red" if has_colours else None
1190
1191        for i, work in enumerate(self):
1192            if nids and work.node_id not in nids: continue
1193            print("", file=stream)
1194            cprint_map("Work #%d: %s, Finalized=%s" % (i, work, work.finalized), cmap={"True": "green"}, file=stream)
1195            if wlist is not None and i in wlist: continue
1196            if verbose == 0 and work.finalized:
1197                print("  Finalized works are not shown. Use verbose > 0 to force output.", file=stream)
1198                continue
1199
1200            headers = ["Task", "Status", "Queue", "MPI|Omp|Gb",
1201                       "Warn|Com", "Class", "Sub|Rest|Corr", "Time",
1202                       "Node_ID"]
1203            table = []
1204            tot_num_errors = 0
1205            for task in work:
1206                if nids and task.node_id not in nids: continue
1207                task_name = os.path.basename(task.name)
1208
1209                # FIXME: This should not be done here.
1210                # get_event_report should be called only in check_status
1211                # Parse the events in the main output.
1212                report = task.get_event_report()
1213
1214                # Get time info (run-time or time in queue or None)
1215                stime = None
1216                timedelta = task.datetimes.get_runtime()
1217                if timedelta is not None:
1218                    stime = str(timedelta) + "R"
1219                else:
1220                    timedelta = task.datetimes.get_time_inqueue()
1221                    if timedelta is not None:
1222                        stime = str(timedelta) + "Q"
1223
1224                events = "|".join(2*["NA"])
1225                if report is not None:
1226                    events = '{:>4}|{:>3}'.format(*map(str, (
1227                       report.num_warnings, report.num_comments)))
1228
1229                para_info = '{:>4}|{:>3}|{:>3}'.format(*map(str, (
1230                   task.mpi_procs, task.omp_threads, "%.1f" % task.mem_per_proc.to("Gb"))))
1231
1232                task_info = list(map(str, [task.__class__.__name__,
1233                                 (task.num_launches, task.num_restarts, task.num_corrections), stime, task.node_id]))
1234
1235                qinfo = "None"
1236                if task.queue_id is not None:
1237                    qname = str(task.qname)
1238                    if not verbose:
1239                        qname = qname[:min(5, len(qname))]
1240                    qinfo = str(task.queue_id) + "@" + qname
1241
1242                if task.status.is_critical:
1243                    tot_num_errors += 1
1244                    task_name = colored(task_name, red)
1245
1246                if has_colours:
1247                    table.append([task_name, task.status.colored, qinfo,
1248                                  para_info, events] + task_info)
1249                else:
1250                    table.append([task_name, str(task.status), qinfo, events,
1251                                  para_info] + task_info)
1252
1253            # Print table and write colorized line with the total number of errors.
1254            print(tabulate(table, headers=headers, tablefmt="grid"), file=stream)
1255            if tot_num_errors:
1256                cprint("Total number of errors: %d" % tot_num_errors, "red", file=stream)
1257            print("", file=stream)
1258
1259        if self.all_ok:
1260            cprint("\nall_ok reached\n", "green", file=stream)
1261
1262    def show_events(self, status=None, nids=None, stream=sys.stdout):
1263        """
1264        Print the Abinit events (ERRORS, WARNIING, COMMENTS) to stdout
1265
1266        Args:
1267            status: if not None, only the tasks with this status are select
1268            nids: optional list of node identifiers used to filter the tasks.
1269            stream: File-like object, Default: sys.stdout
1270        """
1271        nrows, ncols = get_terminal_size()
1272
1273        for task in self.iflat_tasks(status=status, nids=nids):
1274            report = task.get_event_report()
1275            if report:
1276                print(make_banner(str(task), width=ncols, mark="="), file=stream)
1277                print(report, file=stream)
1278                #report = report.filter_types()
1279
1280    def show_corrections(self, status=None, nids=None, stream=sys.stdout):
1281        """
1282        Show the corrections applied to the flow at run-time.
1283
1284        Args:
1285            status: if not None, only the tasks with this status are select.
1286            nids: optional list of node identifiers used to filter the tasks.
1287            stream: File-like object, Default: sys.stdout
1288
1289        Return: The number of corrections found.
1290        """
1291        nrows, ncols = get_terminal_size()
1292        count = 0
1293        for task in self.iflat_tasks(status=status, nids=nids):
1294            if task.num_corrections == 0: continue
1295            count += 1
1296            print(make_banner(str(task), width=ncols, mark="="), file=stream)
1297            for corr in task.corrections:
1298                pprint(corr, stream=stream)
1299
1300        if not count: print("No correction found.", file=stream)
1301        return count
1302
1303    def show_history(self, status=None, nids=None, full_history=False, metadata=False, stream=sys.stdout):
1304        """
1305        Print the history of the flow to stream
1306
1307        Args:
1308            status: if not None, only the tasks with this status are select
1309            full_history: Print full info set, including nodes with an empty history.
1310            nids: optional list of node identifiers used to filter the tasks.
1311            metadata: print history metadata (experimental)
1312            stream: File-like object, Default: sys.stdout
1313        """
1314        nrows, ncols = get_terminal_size()
1315
1316        works_done = []
1317        # Loop on the tasks and show the history of the work is not in works_done
1318        for task in self.iflat_tasks(status=status, nids=nids):
1319            work = task.work
1320
1321            if work not in works_done:
1322                works_done.append(work)
1323                if work.history or full_history:
1324                    cprint(make_banner(str(work), width=ncols, mark="="), file=stream, **work.status.color_opts)
1325                    print(work.history.to_string(metadata=metadata), file=stream)
1326
1327            if task.history or full_history:
1328                cprint(make_banner(str(task), width=ncols, mark="="), file=stream, **task.status.color_opts)
1329                print(task.history.to_string(metadata=metadata), file=stream)
1330
1331        # Print the history of the flow.
1332        if self.history or full_history:
1333            cprint(make_banner(str(self), width=ncols, mark="="), file=stream, **self.status.color_opts)
1334            print(self.history.to_string(metadata=metadata), file=stream)
1335
1336    def show_inputs(self, varnames=None, nids=None, wslice=None, stream=sys.stdout):
1337        """
1338        Print the input of the tasks to the given stream.
1339
1340        Args:
1341            varnames: List of Abinit variables. If not None, only the variable in varnames
1342                are selected and printed.
1343            nids: List of node identifiers. By defaults all nodes are shown
1344            wslice: Slice object used to select works.
1345            stream: File-like object, Default: sys.stdout
1346        """
1347        if varnames is not None:
1348            # Build dictionary varname --> [(task1, value1), (task2, value2), ...]
1349            varnames = [s.strip() for s in list_strings(varnames)]
1350            dlist = collections.defaultdict(list)
1351            for task in self.select_tasks(nids=nids, wslice=wslice):
1352                dstruct = task.input.structure.as_dict(fmt="abivars")
1353
1354                for vname in varnames:
1355                    value = task.input.get(vname, None)
1356                    if value is None: # maybe in structure?
1357                        value = dstruct.get(vname, None)
1358                    if value is not None:
1359                        dlist[vname].append((task, value))
1360
1361            for vname in varnames:
1362                tv_list = dlist[vname]
1363                if not tv_list:
1364                    stream.write("[%s]: Found 0 tasks with this variable\n" % vname)
1365                else:
1366                    stream.write("[%s]: Found %s tasks with this variable\n" % (vname, len(tv_list)))
1367                    for i, (task, value) in enumerate(tv_list):
1368                        stream.write("   %s --> %s\n" % (str(value), task))
1369                stream.write("\n")
1370
1371        else:
1372            lines = []
1373            for task in self.select_tasks(nids=nids, wslice=wslice):
1374                s = task.make_input(with_header=True)
1375
1376                # Add info on dependencies.
1377                if task.deps:
1378                    s += "\n\nDependencies:\n" + "\n".join(str(dep) for dep in task.deps)
1379                else:
1380                    s += "\n\nDependencies: None"
1381
1382                lines.append(2*"\n" + 80 * "=" + "\n" + s + 2*"\n")
1383
1384            stream.writelines(lines)
1385
1386    def listext(self, ext, stream=sys.stdout):
1387        """
1388        Print to the given `stream` a table with the list of the output files
1389        with the given `ext` produced by the flow.
1390        """
1391        nodes_files = []
1392        for node in self.iflat_nodes():
1393            filepath = node.outdir.has_abiext(ext)
1394            if filepath:
1395                nodes_files.append((node, File(filepath)))
1396
1397        if nodes_files:
1398            print("Found %s files with extension `%s` produced by the flow" % (len(nodes_files), ext), file=stream)
1399
1400            table = [[f.relpath, "%.2f" % (f.get_stat().st_size / 1024**2),
1401                      node.node_id, node.__class__.__name__]
1402                     for node, f in nodes_files]
1403            print(tabulate(table, headers=["File", "Size [Mb]", "Node_ID", "Node Class"]), file=stream)
1404
1405        else:
1406            print("No output file with extension %s has been produced by the flow" % ext, file=stream)
1407
1408    def select_tasks(self, nids=None, wslice=None, task_class=None):
1409        """
1410        Return a list with a subset of tasks.
1411
1412        Args:
1413            nids: List of node identifiers.
1414            wslice: Slice object used to select works.
1415            task_class: String or class used to select tasks. Ignored if None.
1416
1417        .. note::
1418
1419            nids and wslice are mutually exclusive.
1420            If no argument is provided, the full list of tasks is returned.
1421        """
1422        if nids is not None:
1423            assert wslice is None
1424            tasks = self.tasks_from_nids(nids)
1425
1426        elif wslice is not None:
1427            tasks = []
1428            for work in self[wslice]:
1429                tasks.extend([t for t in work])
1430        else:
1431            # All tasks selected if no option is provided.
1432            tasks = list(self.iflat_tasks())
1433
1434        # Filter by task class
1435        if task_class is not None:
1436            tasks = [t for t in tasks if t.isinstance(task_class)]
1437
1438        return tasks
1439
1440    def get_task_scfcycles(self, nids=None, wslice=None, task_class=None, exclude_ok_tasks=False):
1441        """
1442        Return list of (taks, scfcycle) tuples for all the tasks in the flow with a SCF algorithm
1443        e.g. electronic GS-SCF iteration, DFPT-SCF iterations etc.
1444
1445        Args:
1446            nids: List of node identifiers.
1447            wslice: Slice object used to select works.
1448            task_class: String or class used to select tasks. Ignored if None.
1449            exclude_ok_tasks: True if only running tasks should be considered.
1450
1451        Returns:
1452            List of `ScfCycle` subclass instances.
1453        """
1454        select_status = [self.S_RUN] if exclude_ok_tasks else [self.S_RUN, self.S_OK]
1455        tasks_cycles = []
1456
1457        for task in self.select_tasks(nids=nids, wslice=wslice):
1458            # Fileter
1459            if task.status not in select_status or task.cycle_class is None:
1460                continue
1461            if task_class is not None and not task.isinstance(task_class):
1462                continue
1463            try:
1464                cycle = task.cycle_class.from_file(task.output_file.path)
1465                if cycle is not None:
1466                    tasks_cycles.append((task, cycle))
1467            except Exception:
1468                # This is intentionally ignored because from_file can fail for several reasons.
1469                pass
1470
1471        return tasks_cycles
1472
1473    def show_tricky_tasks(self, verbose=0, stream=sys.stdout):
1474        """
1475        Print list of tricky tasks i.e. tasks that have been restarted or
1476        launched more than once or tasks with corrections.
1477
1478        Args:
1479            verbose: Verbosity level. If > 0, task history and corrections (if any) are printed.
1480            stream: File-like object. Default: sys.stdout
1481        """
1482        nids, tasks = [], []
1483        for task in self.iflat_tasks():
1484            if task.num_launches > 1 or any(n > 0 for n in (task.num_restarts, task.num_corrections)):
1485                nids.append(task.node_id)
1486                tasks.append(task)
1487
1488        if not nids:
1489            cprint("Everything's fine, no tricky tasks found", color="green", file=stream)
1490        else:
1491            self.show_status(nids=nids, stream=stream)
1492            if not verbose:
1493                print("Use --verbose to print task history.", file=stream)
1494                return
1495
1496            for nid, task in zip(nids, tasks):
1497                cprint(repr(task), **task.status.color_opts, stream=stream)
1498                self.show_history(nids=[nid], full_history=False, metadata=False, stream=stream)
1499                #if task.num_restarts:
1500                #    self.show_restarts(nids=[nid])
1501                if task.num_corrections:
1502                    self.show_corrections(nids=[nid], stream=stream)
1503
1504    def inspect(self, nids=None, wslice=None, **kwargs):
1505        """
1506        Inspect the tasks (SCF iterations, Structural relaxation ...) and
1507        produces matplotlib plots.
1508
1509        Args:
1510            nids: List of node identifiers.
1511            wslice: Slice object used to select works.
1512            kwargs: keyword arguments passed to `task.inspect` method.
1513
1514        .. note::
1515
1516            nids and wslice are mutually exclusive.
1517            If nids and wslice are both None, all tasks in self are inspected.
1518
1519        Returns:
1520            List of `matplotlib` figures.
1521        """
1522        figs = []
1523        for task in self.select_tasks(nids=nids, wslice=wslice):
1524            if hasattr(task, "inspect"):
1525                fig = task.inspect(**kwargs)
1526                if fig is None:
1527                    cprint("Cannot inspect Task %s" % task, color="blue")
1528                else:
1529                    figs.append(fig)
1530            else:
1531                cprint("Task %s does not provide an inspect method" % task, color="blue")
1532
1533        return figs
1534
1535    def get_results(self, **kwargs):
1536        results = self.Results.from_node(self)
1537        results.update(self.get_dict_for_mongodb_queries())
1538        return results
1539
1540    def get_dict_for_mongodb_queries(self):
1541        """
1542        This function returns a dictionary with the attributes that will be
1543        put in the mongodb document to facilitate the query.
1544        Subclasses may want to replace or extend the default behaviour.
1545        """
1546        d = {}
1547        return d
1548        # TODO
1549        all_structures = [task.input.structure for task in self.iflat_tasks()]
1550        all_pseudos = [task.input.pseudos for task in self.iflat_tasks()]
1551
1552    def look_before_you_leap(self):
1553        """
1554        This method should be called before running the calculation to make
1555        sure that the most important requirements are satisfied.
1556
1557        Return:
1558            List of strings with inconsistencies/errors.
1559        """
1560        errors = []
1561
1562        try:
1563            self.check_dependencies()
1564        except self.Error as exc:
1565            errors.append(str(exc))
1566
1567        if self.has_db:
1568            try:
1569                self.manager.db_connector.get_collection()
1570            except Exception as exc:
1571                errors.append("""
1572                    ERROR while trying to connect to the MongoDB database:
1573                        Exception:
1574                            %s
1575                        Connector:
1576                            %s
1577                    """ % (exc, self.manager.db_connector))
1578
1579        return "\n".join(errors)
1580
1581    @property
1582    def has_db(self):
1583        """True if flow uses `MongoDB` to store the results."""
1584        return self.manager.has_db
1585
1586    def db_insert(self):
1587        """
1588        Insert results in the `MongDB` database.
1589        """
1590        assert self.has_db
1591        # Connect to MongoDb and get the collection.
1592        coll = self.manager.db_connector.get_collection()
1593        print("Mongodb collection %s with count %d", coll, coll.count())
1594
1595        start = time.time()
1596        for work in self:
1597            for task in work:
1598                results = task.get_results()
1599                pprint(results)
1600                results.update_collection(coll)
1601            results = work.get_results()
1602            pprint(results)
1603            results.update_collection(coll)
1604        print("MongoDb update done in %s [s]" % time.time() - start)
1605
1606        results = self.get_results()
1607        pprint(results)
1608        results.update_collection(coll)
1609
1610        # Update the pickle file to save the mongo ids.
1611        self.pickle_dump()
1612
1613        for d in coll.find():
1614            pprint(d)
1615
1616    def tasks_from_nids(self, nids):
1617        """
1618        Return the list of tasks associated to the given list of node identifiers (nids).
1619
1620        .. note::
1621
1622            Invalid ids are ignored
1623        """
1624        if not isinstance(nids, collections.abc.Iterable): nids = [nids]
1625
1626        n2task = {task.node_id: task for task in self.iflat_tasks()}
1627        return [n2task[n] for n in nids if n in n2task]
1628
1629    def wti_from_nids(self, nids):
1630        """Return the list of (w, t) indices from the list of node identifiers nids."""
1631        return [task.pos for task in self.tasks_from_nids(nids)]
1632
1633    def open_files(self, what="o", status=None, op="==", nids=None, editor=None):
1634        """
1635        Open the files of the flow inside an editor (command line interface).
1636
1637        Args:
1638            what: string with the list of characters selecting the file type
1639                  Possible choices:
1640
1641                    i ==> input_file,
1642                    o ==> output_file,
1643                    f ==> files_file,
1644                    j ==> job_file,
1645                    l ==> log_file,
1646                    e ==> stderr_file,
1647                    q ==> qout_file,
1648                    all ==> all files.
1649
1650            status: if not None, only the tasks with this status are select
1651            op: status operator. Requires status. A task is selected
1652                if task.status op status evaluates to true.
1653            nids: optional list of node identifiers used to filter the tasks.
1654            editor: Select the editor. None to use the default editor ($EDITOR shell env var)
1655        """
1656        # Build list of files to analyze.
1657        files = []
1658        for task in self.iflat_tasks(status=status, op=op, nids=nids):
1659            lst = task.select_files(what)
1660            if lst:
1661                files.extend(lst)
1662
1663        return Editor(editor=editor).edit_files(files)
1664
1665    def parse_timing(self, nids=None):
1666        """
1667        Parse the timer data in the main output file(s) of Abinit.
1668        Requires timopt /= 0 in the input file (usually timopt = -1)
1669
1670        Args:
1671            nids: optional list of node identifiers used to filter the tasks.
1672
1673        Return: :class:`AbinitTimerParser` instance, None if error.
1674        """
1675        # Get the list of output files according to nids.
1676        paths = [task.output_file.path for task in self.iflat_tasks(nids=nids)]
1677
1678        # Parse data.
1679        from abipy.flowtk.abitimer import AbinitTimerParser
1680        parser = AbinitTimerParser()
1681        read_ok = parser.parse(paths)
1682        if read_ok:
1683            return parser
1684        return None
1685
1686    def show_abierrors(self, nids=None, stream=sys.stdout):
1687        """
1688        Write to the given stream the list of ABINIT errors for all tasks whose status is S_ABICRITICAL.
1689
1690        Args:
1691            nids: optional list of node identifiers used to filter the tasks.
1692            stream: File-like object. Default: sys.stdout
1693        """
1694        lines = []
1695        app = lines.append
1696
1697        for task in self.iflat_tasks(status=self.S_ABICRITICAL, nids=nids):
1698            header = "=== " + task.qout_file.path + "==="
1699            app(header)
1700            report = task.get_event_report()
1701
1702            if report is not None:
1703                app("num_errors: %s, num_warnings: %s, num_comments: %s" % (
1704                    report.num_errors, report.num_warnings, report.num_comments))
1705                app("*** ERRORS ***")
1706                app("\n".join(str(e) for e in report.errors))
1707                app("*** BUGS ***")
1708                app("\n".join(str(b) for b in report.bugs))
1709
1710            else:
1711                app("get_envent_report returned None!")
1712
1713            app("=" * len(header) + 2*"\n")
1714
1715        return stream.writelines(lines)
1716
1717    def show_qouts(self, nids=None, stream=sys.stdout):
1718        """
1719        Write to the given stream the content of the queue output file for all tasks whose status is S_QCRITICAL.
1720
1721        Args:
1722            nids: optional list of node identifiers used to filter the tasks.
1723            stream: File-like object. Default: sys.stdout
1724        """
1725        lines = []
1726
1727        for task in self.iflat_tasks(status=self.S_QCRITICAL, nids=nids):
1728            header = "=== " + task.qout_file.path + "==="
1729            lines.append(header)
1730            if task.qout_file.exists:
1731                with open(task.qout_file.path, "rt") as fh:
1732                    lines += fh.readlines()
1733            else:
1734                lines.append("File does not exist!")
1735
1736            lines.append("=" * len(header) + 2*"\n")
1737
1738        return stream.writelines(lines)
1739
1740    def debug(self, status=None, nids=None, stream=sys.stdout):
1741        """
1742        This method is usually used when the flow didn't completed succesfully
1743        It analyzes the files produced the tasks to facilitate debugging.
1744        Info are printed to stdout.
1745
1746        Args:
1747            status: If not None, only the tasks with this status are selected
1748            nids: optional list of node identifiers used to filter the tasks.
1749            stream: File-like object. Default: sys.stdout
1750        """
1751        nrows, ncols = get_terminal_size()
1752
1753        # Test for scheduler exceptions first.
1754        sched_excfile = os.path.join(self.workdir, "_exceptions")
1755        if os.path.exists(sched_excfile):
1756            with open(sched_excfile, "r") as fh:
1757                cprint("Found exceptions raised by the scheduler", "red", file=stream)
1758                cprint(fh.read(), color="red", file=stream)
1759                return
1760
1761        if status is not None:
1762            tasks = list(self.iflat_tasks(status=status, nids=nids))
1763        else:
1764            errors = list(self.iflat_tasks(status=self.S_ERROR, nids=nids))
1765            qcriticals = list(self.iflat_tasks(status=self.S_QCRITICAL, nids=nids))
1766            abicriticals = list(self.iflat_tasks(status=self.S_ABICRITICAL, nids=nids))
1767            tasks = errors + qcriticals + abicriticals
1768
1769        # For each task selected:
1770        #     1) Check the error files of the task. If not empty, print the content to stdout and we are done.
1771        #     2) If error files are empty, look at the master log file for possible errors
1772        #     3) If also this check failes, scan all the process log files.
1773        #        TODO: This check is not needed if we introduce a new __abinit_error__ file
1774        #        that is created by the first MPI process that invokes MPI abort!
1775        #
1776        ntasks = 0
1777        for task in tasks:
1778            print(make_banner(str(task), width=ncols, mark="="), file=stream)
1779            ntasks += 1
1780
1781            #  Start with error files.
1782            for efname in ["qerr_file", "stderr_file",]:
1783                err_file = getattr(task, efname)
1784                if err_file.exists:
1785                    s = err_file.read()
1786                    if not s: continue
1787                    print(make_banner(str(err_file), width=ncols, mark="="), file=stream)
1788                    cprint(s, color="red", file=stream)
1789                    #count += 1
1790
1791            # Check main log file.
1792            try:
1793                report = task.get_event_report()
1794                if report and report.num_errors:
1795                    print(make_banner(os.path.basename(report.filename), width=ncols, mark="="), file=stream)
1796                    s = "\n".join(str(e) for e in report.errors)
1797                else:
1798                    s = None
1799            except Exception as exc:
1800                s = str(exc)
1801
1802            count = 0 # count > 0 means we found some useful info that could explain the failures.
1803            if s is not None:
1804                cprint(s, color="red", file=stream)
1805                count += 1
1806
1807            if not count:
1808                # Inspect all log files produced by the other nodes.
1809                log_files = task.tmpdir.list_filepaths(wildcard="*LOG_*")
1810                if not log_files:
1811                    cprint("No *LOG_* file in tmpdir. This usually happens if you are running with many CPUs",
1812                           color="magenta", file=stream)
1813
1814                for log_file in log_files:
1815                    try:
1816                        report = EventsParser().parse(log_file)
1817                        if report.errors:
1818                            print(report, file=stream)
1819                            count += 1
1820                            break
1821                    except Exception as exc:
1822                        cprint(str(exc), color="red", file=stream)
1823                        count += 1
1824                        break
1825
1826            if not count:
1827                cprint("Houston, we could not find any error message that can explain the problem",
1828                        color="magenta", file=stream)
1829
1830        print("Number of tasks analyzed: %d" % ntasks, file=stream)
1831
1832    def cancel(self, nids=None):
1833        """
1834        Cancel all the tasks that are in the queue.
1835        nids is an optional list of node identifiers used to filter the tasks.
1836
1837        Returns:
1838            Number of jobs cancelled, negative value if error
1839        """
1840        if self.has_chrooted:
1841            # TODO: Use paramiko to kill the job?
1842            warnings.warn("Cannot cancel the flow via sshfs!")
1843            return -1
1844
1845        # If we are running with the scheduler, we must send a SIGKILL signal.
1846        if os.path.exists(self.pid_file):
1847            cprint("Found scheduler attached to this flow.", "yellow")
1848            cprint("Sending SIGKILL to the scheduler before cancelling the tasks!", "yellow")
1849
1850            with open(self.pid_file, "rt") as fh:
1851                pid = int(fh.readline())
1852
1853            retcode = os.system("kill -9 %d" % pid)
1854            self.history.info("Sent SIGKILL to the scheduler, retcode: %s" % retcode)
1855            try:
1856                os.remove(self.pid_file)
1857            except IOError:
1858                pass
1859
1860        num_cancelled = 0
1861        for task in self.iflat_tasks(nids=nids):
1862            num_cancelled += task.cancel()
1863
1864        return num_cancelled
1865
1866    def get_njobs_in_queue(self, username=None):
1867        """
1868        returns the number of jobs in the queue, None when the number of jobs cannot be determined.
1869
1870        Args:
1871            username: (str) the username of the jobs to count (default is to autodetect)
1872        """
1873        return self.manager.qadapter.get_njobs_in_queue(username=username)
1874
1875    def rmtree(self, ignore_errors=False, onerror=None):
1876        """Remove workdir (same API as shutil.rmtree)."""
1877        if not os.path.exists(self.workdir): return
1878        shutil.rmtree(self.workdir, ignore_errors=ignore_errors, onerror=onerror)
1879
1880    def rm_and_build(self):
1881        """Remove the workdir and rebuild the flow."""
1882        self.rmtree()
1883        self.build()
1884
1885    def build(self, *args, **kwargs):
1886        """Make directories and files of the `Flow`."""
1887        # Allocate here if not done yet!
1888        if not self.allocated: self.allocate()
1889
1890        self.indir.makedirs()
1891        self.outdir.makedirs()
1892        self.tmpdir.makedirs()
1893
1894        # Check the nodeid file in workdir
1895        nodeid_path = os.path.join(self.workdir, ".nodeid")
1896
1897        if os.path.exists(nodeid_path):
1898            with open(nodeid_path, "rt") as fh:
1899                node_id = int(fh.read())
1900
1901            if self.node_id != node_id:
1902                msg = ("\nFound node_id %s in file:\n\n  %s\n\nwhile the node_id of the present flow is %d.\n"
1903                       "This means that you are trying to build a new flow in a directory already used by another flow.\n"
1904                       "Possible solutions:\n"
1905                       "   1) Change the workdir of the new flow.\n"
1906                       "   2) remove the old directory either with `rm -r` or by calling the method flow.rmtree()\n"
1907                       % (node_id, nodeid_path, self.node_id))
1908                raise RuntimeError(msg)
1909
1910        else:
1911            with open(nodeid_path, "wt") as fh:
1912                fh.write(str(self.node_id))
1913
1914        if self.pyfile and os.path.isfile(self.pyfile):
1915            shutil.copy(self.pyfile, self.workdir)
1916
1917        for work in self:
1918            work.build(*args, **kwargs)
1919
1920    def build_and_pickle_dump(self, abivalidate=False):
1921        """
1922        Build dirs and file of the `Flow` and save the object in pickle format.
1923        Returns 0 if success
1924
1925        Args:
1926            abivalidate: If True, all the input files are validate by calling
1927                the abinit parser. If the validation fails, ValueError is raise.
1928        """
1929        self.build()
1930        if not abivalidate: return self.pickle_dump()
1931
1932        # Validation with Abinit.
1933        isok, errors = self.abivalidate_inputs()
1934        if isok: return self.pickle_dump()
1935        errlines = []
1936        for i, e in enumerate(errors):
1937            errlines.append("[%d] %s" % (i, e))
1938        raise ValueError("\n".join(errlines))
1939
1940    @check_spectator
1941    def pickle_dump(self):
1942        """
1943        Save the status of the object in pickle format.
1944        Returns 0 if success
1945        """
1946        if self.has_chrooted:
1947            warnings.warn("Cannot pickle_dump since we have chrooted from %s" % self.has_chrooted)
1948            return -1
1949
1950        #if self.in_spectator_mode:
1951        #    warnings.warn("Cannot pickle_dump since flow is in_spectator_mode")
1952        #    return -2
1953
1954        protocol = self.pickle_protocol
1955
1956        # Atomic transaction with FileLock.
1957        with FileLock(self.pickle_file):
1958            with AtomicFile(self.pickle_file, mode="wb") as fh:
1959                pmg_pickle_dump(self, fh, protocol=protocol)
1960
1961        return 0
1962
1963    def pickle_dumps(self, protocol=None):
1964        """
1965        Return a string with the pickle representation.
1966        `protocol` selects the pickle protocol. self.pickle_protocol is used if `protocol` is None
1967        """
1968        strio = StringIO()
1969        pmg_pickle_dump(self, strio,
1970                        protocol=self.pickle_protocol if protocol is None
1971                        else protocol)
1972        return strio.getvalue()
1973
1974    def register_task(self, input, deps=None, manager=None, task_class=None, append=False):
1975        """
1976        Utility function that generates a `Work` made of a single task
1977
1978        Args:
1979            input: |AbinitInput|
1980            deps: List of :class:`Dependency` objects specifying the dependency of this node.
1981                  An empy list of deps implies that this node has no dependencies.
1982            manager: The |TaskManager| responsible for the submission of the task.
1983                     If manager is None, we use the |TaskManager| specified during the creation of the work.
1984            task_class: Task subclass to instantiate. Default: |AbinitTask|
1985            append: If true, the task is added to the last work (a new Work is created if flow is empty)
1986
1987        Returns:
1988            The generated |Work| for the task, work[0] is the actual task.
1989        """
1990        # append True is much easier to use. In principle should be the default behaviour
1991        # but this would break the previous API so ...
1992        if not append:
1993            work = Work(manager=manager)
1994        else:
1995            if not self.works:
1996                work = Work(manager=manager)
1997                append = False
1998            else:
1999                work = self.works[-1]
2000
2001        task = work.register(input, deps=deps, task_class=task_class)
2002        if not append: self.register_work(work)
2003
2004        return work
2005
2006    def register_work(self, work, deps=None, manager=None, workdir=None):
2007        """
2008        Register a new |Work| and add it to the internal list, taking into account possible dependencies.
2009
2010        Args:
2011            work: |Work| object.
2012            deps: List of :class:`Dependency` objects specifying the dependency of this node.
2013                  An empy list of deps implies that this node has no dependencies.
2014            manager: The |TaskManager| responsible for the submission of the task.
2015                     If manager is None, we use the `TaskManager` specified during the creation of the work.
2016            workdir: The name of the directory used for the |Work|.
2017
2018        Returns:
2019            The registered |Work|.
2020        """
2021        if getattr(self, "workdir", None) is not None:
2022            # The flow has a directory, build the name of the directory of the work.
2023            work_workdir = None
2024            if workdir is None:
2025                work_workdir = os.path.join(self.workdir, "w" + str(len(self)))
2026            else:
2027                work_workdir = os.path.join(self.workdir, os.path.basename(workdir))
2028
2029            work.set_workdir(work_workdir)
2030
2031        if manager is not None:
2032            work.set_manager(manager)
2033
2034        self.works.append(work)
2035
2036        if deps:
2037            deps = [Dependency(node, exts) for node, exts in deps.items()]
2038            work.add_deps(deps)
2039
2040        return work
2041
2042    def register_work_from_cbk(self, cbk_name, cbk_data, deps, work_class, manager=None):
2043        """
2044        Registers a callback function that will generate the :class:`Task` of the :class:`Work`.
2045
2046        Args:
2047            cbk_name: Name of the callback function (must be a bound method of self)
2048            cbk_data: Additional data passed to the callback function.
2049            deps: List of :class:`Dependency` objects specifying the dependency of the work.
2050            work_class: |Work| class to instantiate.
2051            manager: The |TaskManager| responsible for the submission of the task.
2052                    If manager is None, we use the `TaskManager` specified during the creation of the |Flow|.
2053
2054        Returns:
2055            The |Work| that will be finalized by the callback.
2056        """
2057        # TODO: pass a Work factory instead of a class
2058        # Directory of the Work.
2059        work_workdir = os.path.join(self.workdir, "w" + str(len(self)))
2060
2061        # Create an empty work and register the callback
2062        work = work_class(workdir=work_workdir, manager=manager)
2063
2064        self._works.append(work)
2065
2066        deps = [Dependency(node, exts) for node, exts in deps.items()]
2067        if not deps:
2068            raise ValueError("A callback must have deps!")
2069
2070        work.add_deps(deps)
2071
2072        # Wrap the callable in a Callback object and save
2073        # useful info such as the index of the work and the callback data.
2074        cbk = FlowCallback(cbk_name, self, deps=deps, cbk_data=cbk_data)
2075        self._callbacks.append(cbk)
2076
2077        return work
2078
2079    @property
2080    def allocated(self):
2081        """Numer of allocations. Set by `allocate`."""
2082        try:
2083            return self._allocated
2084        except AttributeError:
2085            return 0
2086
2087    def allocate(self, workdir=None, use_smartio=False):
2088        """
2089        Allocate the `Flow` i.e. assign the `workdir` and (optionally)
2090        the |TaskManager| to the different tasks in the Flow.
2091
2092        Args:
2093            workdir: Working directory of the flow. Must be specified here
2094                if we haven't initialized the workdir in the __init__.
2095
2096        Return:
2097            self
2098        """
2099        if workdir is not None:
2100            # We set the workdir of the flow here
2101            self.set_workdir(workdir)
2102            for i, work in enumerate(self):
2103                work.set_workdir(os.path.join(self.workdir, "w" + str(i)))
2104
2105        if not hasattr(self, "workdir"):
2106            raise RuntimeError("You must call flow.allocate(workdir) if the workdir is not passed to __init__")
2107
2108        for work in self:
2109            # Each work has a reference to its flow.
2110            work.allocate(manager=self.manager)
2111            work.set_flow(self)
2112            # Each task has a reference to its work.
2113            for task in work:
2114                task.set_work(work)
2115
2116        self.check_dependencies()
2117
2118        if not hasattr(self, "_allocated"): self._allocated = 0
2119        self._allocated += 1
2120
2121        if use_smartio:
2122            self.use_smartio()
2123
2124        return self
2125
2126    def use_smartio(self):
2127        """
2128        This function should be called when the entire `Flow` has been built.
2129        It tries to reduce the pressure on the hard disk by using Abinit smart-io
2130        capabilities for those files that are not needed by other nodes.
2131        Smart-io means that big files (e.g. WFK) are written only if the calculation
2132        is unconverged so that we can restart from it. No output is produced if
2133        convergence is achieved.
2134
2135        Return: self
2136        """
2137        if not self.allocated:
2138            #raise RuntimeError("You must call flow.allocate() before invoking flow.use_smartio()")
2139            self.allocate()
2140
2141        for task in self.iflat_tasks():
2142            children = task.get_children()
2143            if not children:
2144                # Change the input so that output files are produced
2145                # only if the calculation is not converged.
2146                task.history.info("Will disable IO for task")
2147                task.set_vars(prtwf=-1, prtden=0) # TODO: prt1wf=-1,
2148            else:
2149                must_produce_abiexts = []
2150                for child in children:
2151                    # Get the list of dependencies. Find that task
2152                    for d in child.deps:
2153                        must_produce_abiexts.extend(d.exts)
2154
2155                must_produce_abiexts = set(must_produce_abiexts)
2156                #print("must_produce_abiexts", must_produce_abiexts)
2157
2158                # Variables supporting smart-io.
2159                smart_prtvars = {
2160                    "prtwf": "WFK",
2161                }
2162
2163                # Set the variable to -1 to disable the output
2164                for varname, abiext in smart_prtvars.items():
2165                    if abiext not in must_produce_abiexts:
2166                        print("%s: setting %s to -1" % (task, varname))
2167                        task.set_vars({varname: -1})
2168
2169        return self
2170
2171    def show_dependencies(self, stream=sys.stdout):
2172        """
2173        Writes to the given stream the ASCII representation of the dependency tree.
2174        """
2175        def child_iter(node):
2176            return [d.node for d in node.deps]
2177
2178        def text_str(node):
2179            return colored(str(node), color=node.status.color_opts["color"])
2180
2181        for task in self.iflat_tasks():
2182            print(draw_tree(task, child_iter, text_str), file=stream)
2183
2184    def on_all_ok(self):
2185        """
2186        This method is called when all the works in the flow have reached S_OK.
2187        This method shall return True if the calculation is completed or
2188        False if the execution should continue due to side-effects such as adding a new work to the flow.
2189
2190        This methods allows subclasses to implement customized logic such as extending the flow by adding new works.
2191        The flow has an internal counter: on_all_ok_num_calls
2192        that shall be incremented by client code when subclassing this method.
2193        This counter can be used to decide if futher actions are needed or not.
2194
2195        An example of flow that adds a new work (only once) when all_ok is reached for the first time:
2196
2197        def on_all_ok(self):
2198            if self.on_all_ok_num_calls > 0: return True
2199            self.on_all_ok_num_calls += 1
2200
2201            `implement_logic_to_create_new_work`
2202
2203            self.register_work(work)
2204            self.allocate()
2205            self.build_and_pickle_dump()
2206
2207            return False # The scheduler will keep on running the flow.
2208        """
2209        return True
2210
2211    def on_dep_ok(self, signal, sender):
2212        # TODO
2213        # Replace this callback with dynamic dispatch
2214        # on_all_S_OK for work
2215        # on_S_OK for task
2216        self.history.info("on_dep_ok with sender %s, signal %s" % (str(sender), signal))
2217
2218        for i, cbk in enumerate(self._callbacks):
2219            if not cbk.handle_sender(sender):
2220                self.history.info("%s does not handle sender %s" % (cbk, sender))
2221                continue
2222
2223            if not cbk.can_execute():
2224                self.history.info("Cannot execute %s" % cbk)
2225                continue
2226
2227            # Execute the callback and disable it
2228            self.history.info("flow in on_dep_ok: about to execute callback %s" % str(cbk))
2229            cbk()
2230            cbk.disable()
2231
2232            # Update the database.
2233            self.pickle_dump()
2234
2235    @check_spectator
2236    def finalize(self):
2237        """
2238        This method is called when the flow is completed. Return 0 if success
2239        """
2240        self.history.info("Calling flow.finalize.")
2241        if self.finalized:
2242            self.history.warning("Calling finalize on an already finalized flow. Returning 1")
2243            return 1
2244
2245        self.finalized = True
2246
2247        if self.has_db:
2248            self.history.info("Saving results in database.")
2249            try:
2250                self.flow.db_insert()
2251                self.finalized = True
2252            except Exception:
2253                self.history.critical("MongoDb insertion failed.")
2254                return 2
2255
2256        # Here we remove the big output files if we have the garbage collector
2257        # and the policy is set to "flow."
2258        if self.gc is not None and self.gc.policy == "flow":
2259            self.history.info("gc.policy set to flow. Will clean task output files.")
2260            for task in self.iflat_tasks():
2261                task.clean_output_files()
2262
2263        return 0
2264
2265    def set_garbage_collector(self, exts=None, policy="task"):
2266        """
2267        Enable the garbage collector that will remove the big output files that are not needed.
2268
2269        Args:
2270            exts: string or list with the Abinit file extensions to be removed. A default is
2271                provided if exts is None
2272            policy: Either `flow` or `task`. If policy is set to 'task', we remove the output
2273                files as soon as the task reaches S_OK. If 'flow', the files are removed
2274                only when the flow is finalized. This option should be used when we are dealing
2275                with a dynamic flow with callbacks generating other tasks since a |Task|
2276                might not be aware of its children when it reached S_OK.
2277        """
2278        assert policy in ("task", "flow")
2279        exts = list_strings(exts) if exts is not None else ("WFK", "SUS", "SCR", "BSR", "BSC")
2280
2281        gc = GarbageCollector(exts=set(exts), policy=policy)
2282
2283        self.set_gc(gc)
2284        for work in self:
2285            #work.set_gc(gc) # TODO Add support for Works and flow policy
2286            for task in work:
2287                task.set_gc(gc)
2288
2289    def connect_signals(self):
2290        """
2291        Connect the signals within the `Flow`.
2292        The `Flow` is responsible for catching the important signals raised from its works.
2293        """
2294        # Connect the signals inside each Work.
2295        for work in self:
2296            work.connect_signals()
2297
2298        # Observe the nodes that must reach S_OK in order to call the callbacks.
2299        for cbk in self._callbacks:
2300            #cbk.enable()
2301            for dep in cbk.deps:
2302                self.history.info("Connecting %s \nwith sender %s, signal %s" % (str(cbk), dep.node, dep.node.S_OK))
2303                dispatcher.connect(self.on_dep_ok, signal=dep.node.S_OK, sender=dep.node, weak=False)
2304
2305        # Associate to each signal the callback _on_signal
2306        # (bound method of the node that will be called by `Flow`
2307        # Each node will set its attribute _done_signal to True to tell
2308        # the flow that this callback should be disabled.
2309
2310        # Register the callbacks for the Work.
2311        #for work in self:
2312        #    slot = self._sig_slots[work]
2313        #    for signal in S_ALL:
2314        #        done_signal = getattr(work, "_done_ " + signal, False)
2315        #        if not done_sig:
2316        #            cbk_name = "_on_" + str(signal)
2317        #            cbk = getattr(work, cbk_name, None)
2318        #            if cbk is None: continue
2319        #            slot[work][signal].append(cbk)
2320        #            print("connecting %s\nwith sender %s, signal %s" % (str(cbk), dep.node, dep.node.S_OK))
2321        #            dispatcher.connect(self.on_dep_ok, signal=signal, sender=dep.node, weak=False)
2322
2323        # Register the callbacks for the Tasks.
2324        #self.show_receivers()
2325
2326    def disconnect_signals(self):
2327        """Disable the signals within the `Flow`."""
2328        # Disconnect the signals inside each Work.
2329        for work in self:
2330            work.disconnect_signals()
2331
2332        # Disable callbacks.
2333        for cbk in self._callbacks:
2334            cbk.disable()
2335
2336    def show_receivers(self, sender=None, signal=None):
2337        sender = sender if sender is not None else dispatcher.Any
2338        signal = signal if signal is not None else dispatcher.Any
2339        print("*** live receivers ***")
2340        for rec in dispatcher.liveReceivers(dispatcher.getReceivers(sender, signal)):
2341            print("receiver -->", rec)
2342        print("*** end live receivers ***")
2343
2344    def set_spectator_mode(self, mode=True):
2345        """
2346        When the flow is in spectator_mode, we have to disable signals, pickle dump and possible callbacks
2347        A spectator can still operate on the flow but the new status of the flow won't be saved in
2348        the pickle file. Usually the flow is in spectator mode when we are already running it via
2349        the scheduler or other means and we should not interfere with its evolution.
2350        This is the reason why signals and callbacks must be disabled.
2351        Unfortunately preventing client-code from calling methods with side-effects when
2352        the flow is in spectator mode is not easy (e.g. flow.cancel will cancel the tasks submitted to the
2353        queue and the flow used by the scheduler won't see this change!
2354        """
2355        # Set the flags of all the nodes in the flow.
2356        mode = bool(mode)
2357        self.in_spectator_mode = mode
2358        for node in self.iflat_nodes():
2359            node.in_spectator_mode = mode
2360
2361        # connect/disconnect signals depending on mode.
2362        if not mode:
2363            self.connect_signals()
2364        else:
2365            self.disconnect_signals()
2366
2367    #def get_results(self, **kwargs)
2368
2369    def rapidfire(self, check_status=True, max_nlaunch=-1, max_loops=1, sleep_time=5, **kwargs):
2370        """
2371        Use :class:`PyLauncher` to submits tasks in rapidfire mode.
2372        kwargs contains the options passed to the launcher.
2373
2374        Args:
2375            check_status:
2376            max_nlaunch: Maximum number of launches. default: no limit.
2377            max_loops: Maximum number of loops
2378            sleep_time: seconds to sleep between rapidfire loop iterations
2379
2380        Return: Number of tasks submitted.
2381        """
2382        self.check_pid_file()
2383        self.set_spectator_mode(False)
2384        if check_status: self.check_status()
2385        from .launcher import PyLauncher
2386        return PyLauncher(self, **kwargs).rapidfire(max_nlaunch=max_nlaunch, max_loops=max_loops, sleep_time=sleep_time)
2387
2388    def single_shot(self, check_status=True, **kwargs):
2389        """
2390        Use :class:`PyLauncher` to submits one task.
2391        kwargs contains the options passed to the launcher.
2392
2393        Return: Number of tasks submitted.
2394        """
2395        self.check_pid_file()
2396        self.set_spectator_mode(False)
2397        if check_status: self.check_status()
2398        from .launcher import PyLauncher
2399        return PyLauncher(self, **kwargs).single_shot()
2400
2401    def make_scheduler(self, **kwargs):
2402        """
2403        Build and return a :class:`PyFlowScheduler` to run the flow.
2404
2405        Args:
2406            kwargs: if empty we use the user configuration file.
2407                    if `filepath` in kwargs we init the scheduler from filepath.
2408                    else pass kwargs to :class:`PyFlowScheduler` __init__ method.
2409        """
2410        from .launcher import PyFlowScheduler
2411        if not kwargs:
2412            # User config if kwargs is empty
2413            sched = PyFlowScheduler.from_user_config()
2414        else:
2415            # Use from_file if filepath if present, else call __init__
2416            filepath = kwargs.pop("filepath", None)
2417            if filepath is not None:
2418                assert not kwargs
2419                sched = PyFlowScheduler.from_file(filepath)
2420            else:
2421                sched = PyFlowScheduler(**kwargs)
2422
2423        sched.add_flow(self)
2424        return sched
2425
2426    def batch(self, timelimit=None):
2427        """
2428        Run the flow in batch mode, return exit status of the job script.
2429        Requires a manager.yml file and a batch_adapter adapter.
2430
2431        Args:
2432            timelimit: Time limit (int with seconds or string with time given with the slurm convention:
2433            "days-hours:minutes:seconds"). If timelimit is None, the default value specified in the
2434            `batch_adapter` entry of `manager.yml` is used.
2435        """
2436        from .launcher import BatchLauncher
2437        # Create a batch dir from the flow.workdir.
2438        prev_dir = os.path.join(*self.workdir.split(os.path.sep)[:-1])
2439        prev_dir = os.path.join(os.path.sep, prev_dir)
2440        workdir = os.path.join(prev_dir, os.path.basename(self.workdir) + "_batch")
2441
2442        return BatchLauncher(workdir=workdir, flows=self).submit(timelimit=timelimit)
2443
2444    def make_light_tarfile(self, name=None):
2445        """Lightweight tarball file. Mainly used for debugging. Return the name of the tarball file."""
2446        name = os.path.basename(self.workdir) + "-light.tar.gz" if name is None else name
2447        return self.make_tarfile(name=name, exclude_dirs=["outdata", "indata", "tmpdata"])
2448
2449    def make_tarfile(self, name=None, max_filesize=None, exclude_exts=None, exclude_dirs=None, verbose=0, **kwargs):
2450        """
2451        Create a tarball file.
2452
2453        Args:
2454            name: Name of the tarball file. Set to os.path.basename(`flow.workdir`) + "tar.gz"` if name is None.
2455            max_filesize (int or string with unit): a file is included in the tar file if its size <= max_filesize
2456                Can be specified in bytes e.g. `max_files=1024` or with a string with unit e.g. `max_filesize="1 Mb"`.
2457                No check is done if max_filesize is None.
2458            exclude_exts: List of file extensions to be excluded from the tar file.
2459            exclude_dirs: List of directory basenames to be excluded.
2460            verbose (int): Verbosity level.
2461            kwargs: keyword arguments passed to the :class:`TarFile` constructor.
2462
2463        Returns: The name of the tarfile.
2464        """
2465        def any2bytes(s):
2466            """Convert string or number to memory in bytes."""
2467            if is_string(s):
2468                return int(Memory.from_string(s).to("b"))
2469            else:
2470                return int(s)
2471
2472        if max_filesize is not None:
2473            max_filesize = any2bytes(max_filesize)
2474
2475        if exclude_exts:
2476            # Add/remove ".nc" so that we can simply pass "GSR" instead of "GSR.nc"
2477            # Moreover this trick allows one to treat WFK.nc and WFK file on the same footing.
2478            exts = []
2479            for e in list_strings(exclude_exts):
2480                exts.append(e)
2481                if e.endswith(".nc"):
2482                    exts.append(e.replace(".nc", ""))
2483                else:
2484                    exts.append(e + ".nc")
2485            exclude_exts = exts
2486
2487        def filter(tarinfo):
2488            """
2489            Function that takes a TarInfo object argument and returns the changed TarInfo object.
2490            If it instead returns None the TarInfo object will be excluded from the archive.
2491            """
2492            # Skip links.
2493            if tarinfo.issym() or tarinfo.islnk():
2494                if verbose: print("Excluding link: %s" % tarinfo.name)
2495                return None
2496
2497            # Check size in bytes
2498            if max_filesize is not None and tarinfo.size > max_filesize:
2499                if verbose: print("Excluding %s due to max_filesize" % tarinfo.name)
2500                return None
2501
2502            # Filter filenames.
2503            if exclude_exts and any(tarinfo.name.endswith(ext) for ext in exclude_exts):
2504                if verbose: print("Excluding %s due to extension" % tarinfo.name)
2505                return None
2506
2507            # Exlude directories (use dir basenames).
2508            if exclude_dirs and any(dir_name in exclude_dirs for dir_name in tarinfo.name.split(os.path.sep)):
2509                if verbose: print("Excluding %s due to exclude_dirs" % tarinfo.name)
2510                return None
2511
2512            return tarinfo
2513
2514        back = os.getcwd()
2515        os.chdir(os.path.join(self.workdir, ".."))
2516
2517        import tarfile
2518        name = os.path.basename(self.workdir) + ".tar.gz" if name is None else name
2519        with tarfile.open(name=name, mode='w:gz', **kwargs) as tar:
2520            tar.add(os.path.basename(self.workdir), arcname=None, recursive=True, filter=filter)
2521
2522            # Add the script used to generate the flow.
2523            if self.pyfile is not None and os.path.exists(self.pyfile):
2524                tar.add(self.pyfile)
2525
2526        os.chdir(back)
2527        return name
2528
2529    def get_graphviz(self, engine="automatic", graph_attr=None, node_attr=None, edge_attr=None):
2530        """
2531        Generate flow graph in the DOT language.
2532
2533        Args:
2534            engine: Layout command used. ['dot', 'neato', 'twopi', 'circo', 'fdp', 'sfdp', 'patchwork', 'osage']
2535            graph_attr: Mapping of (attribute, value) pairs for the graph.
2536            node_attr: Mapping of (attribute, value) pairs set for all nodes.
2537            edge_attr: Mapping of (attribute, value) pairs set for all edges.
2538
2539        Returns: graphviz.Digraph <https://graphviz.readthedocs.io/en/stable/api.html#digraph>
2540        """
2541        self.allocate()
2542
2543        from graphviz import Digraph
2544        fg = Digraph("flow", #filename="flow_%s.gv" % os.path.basename(self.relworkdir),
2545            engine="fdp" if engine == "automatic" else engine)
2546
2547        # Set graph attributes. https://www.graphviz.org/doc/info/
2548        fg.attr(label=repr(self))
2549        fg.attr(rankdir="LR", pagedir="BL")
2550        fg.node_attr.update(color='lightblue2', style='filled')
2551
2552        # Add input attributes.
2553        if graph_attr is not None: fg.graph_attr.update(**graph_attr)
2554        if node_attr is not None: fg.node_attr.update(**node_attr)
2555        if edge_attr is not None: fg.edge_attr.update(**edge_attr)
2556
2557        def node_kwargs(node):
2558            return dict(
2559                #shape="circle",
2560                color=node.color_hex,
2561                fontsize="8.0",
2562                label=(str(node) if not hasattr(node, "pos_str") else
2563                    node.pos_str + "\n" + node.__class__.__name__),
2564            )
2565
2566        edge_kwargs = dict(arrowType="vee", style="solid")
2567        cluster_kwargs = dict(rankdir="LR", pagedir="BL", style="rounded", bgcolor="azure2")
2568
2569        for work in self:
2570            # Build cluster with tasks.
2571            cluster_name = "cluster%s" % work.name
2572            with fg.subgraph(name=cluster_name) as wg:
2573                wg.attr(**cluster_kwargs)
2574                wg.attr(label="%s (%s)" % (work.__class__.__name__, work.name))
2575                for task in work:
2576                    wg.node(task.name, **node_kwargs(task))
2577                    # Connect children to task.
2578                    for child in task.get_children():
2579                        # Find file extensions required by this task
2580                        i = [dep.node for dep in child.deps].index(task)
2581                        edge_label = "+".join(child.deps[i].exts)
2582                        fg.edge(task.name, child.name, label=edge_label, color=task.color_hex,
2583                                **edge_kwargs)
2584
2585        # Treat the case in which we have a work producing output for other tasks.
2586        for work in self:
2587            children = work.get_children()
2588            if not children: continue
2589            cluster_name = "cluster%s" % work.name
2590            seen = set()
2591            for child in children:
2592                # This is not needed, too much confusing
2593                #fg.edge(cluster_name, child.name, color=work.color_hex, **edge_kwargs)
2594                # Find file extensions required by work
2595                i = [dep.node for dep in child.deps].index(work)
2596                for ext in child.deps[i].exts:
2597                    out = "%s (%s)" % (ext, work.name)
2598                    fg.node(out)
2599                    fg.edge(out, child.name, **edge_kwargs)
2600                    key = (cluster_name, out)
2601                    if key not in seen:
2602                        seen.add(key)
2603                        fg.edge(cluster_name, out, color=work.color_hex, **edge_kwargs)
2604
2605        # Treat the case in which we have a task that depends on external files.
2606        seen = set()
2607        for task in self.iflat_tasks():
2608            #print(task.get_parents())
2609            for node in (p for p in task.get_parents() if p.is_file):
2610                #print("parent file node", node)
2611                #infile = "%s (%s)" % (ext, work.name)
2612                infile = node.filepath
2613                if infile not in seen:
2614                    seen.add(infile)
2615                    fg.node(infile, **node_kwargs(node))
2616
2617                fg.edge(infile, task.name, color=node.color_hex, **edge_kwargs)
2618
2619        return fg
2620
2621    @add_fig_kwargs
2622    def graphviz_imshow(self, ax=None, figsize=None, dpi=300, fmt="png", **kwargs):
2623        """
2624        Generate flow graph in the DOT language and plot it with matplotlib.
2625
2626        Args:
2627            ax: |matplotlib-Axes| or None if a new figure should be created.
2628            figsize: matplotlib figure size (None to use default)
2629            dpi: DPI value.
2630            fmt: Select format for output image
2631
2632        Return: |matplotlib-Figure|
2633        """
2634        graph = self.get_graphviz(**kwargs)
2635        graph.format = fmt
2636        graph.attr(dpi=str(dpi))
2637        _, tmpname = tempfile.mkstemp()
2638        path = graph.render(tmpname, view=False, cleanup=True)
2639        ax, fig, _ = get_ax_fig_plt(ax=ax, figsize=figsize, dpi=dpi)
2640        import matplotlib.image as mpimg
2641        ax.imshow(mpimg.imread(path, format="png")) #, interpolation="none")
2642        ax.axis("off")
2643
2644        return fig
2645
2646    @add_fig_kwargs
2647    def plot_networkx(self, mode="network", with_edge_labels=False, ax=None, arrows=False,
2648                      node_size="num_cores", node_label="name_class", layout_type="spring", **kwargs):
2649        """
2650        Use networkx to draw the flow with the connections among the nodes and
2651        the status of the tasks.
2652
2653        Args:
2654            mode: `networkx` to show connections, `status` to group tasks by status.
2655            with_edge_labels: True to draw edge labels.
2656            ax: |matplotlib-Axes| or None if a new figure should be created.
2657            arrows: if True draw arrowheads.
2658            node_size: By default, the size of the node is proportional to the number of cores used.
2659            node_label: By default, the task class is used to label node.
2660            layout_type: Get positions for all nodes using `layout_type`. e.g. pos = nx.spring_layout(g)
2661
2662        .. warning::
2663
2664            Requires networkx package.
2665        """
2666        self.allocate()
2667
2668        # Build the graph
2669        import networkx as nx
2670        g = nx.Graph() if not arrows else nx.DiGraph()
2671        edge_labels = {}
2672        for task in self.iflat_tasks():
2673            g.add_node(task, name=task.name)
2674            for child in task.get_children():
2675                g.add_edge(task, child)
2676                # TODO: Add getters! What about locked nodes!
2677                i = [dep.node for dep in child.deps].index(task)
2678                edge_labels[(task, child)] = " ".join(child.deps[i].exts)
2679
2680            filedeps = [d for d in task.deps if d.node.is_file]
2681            for d in filedeps:
2682                #print(d.node, d.exts)
2683                g.add_node(d.node, name="%s (%s)" % (d.node.basename, d.node.node_id))
2684                g.add_edge(d.node, task)
2685                edge_labels[(d.node, task)] = "+".join(d.exts)
2686
2687        # This part is needed if we have a work that produces output used by other nodes.
2688        for work in self:
2689            children = work.get_children()
2690            if not children:
2691                continue
2692            g.add_node(work, name=work.name)
2693            for task in work:
2694                g.add_edge(task, work)
2695                edge_labels[(task, work)] = "all_ok "
2696            for child in children:
2697                g.add_edge(work, child)
2698                i = [dep.node for dep in child.deps].index(work)
2699                edge_labels[(work, child)] = "+".join(child.deps[i].exts)
2700
2701        # Get positions for all nodes using layout_type.
2702        # e.g. pos = nx.spring_layout(g)
2703        pos = getattr(nx, layout_type + "_layout")(g)
2704
2705        # Select function used to compute the size of the node
2706        def make_node_size(node):
2707            if node.is_task:
2708                return 300 * node.manager.num_cores
2709            else:
2710                return 600
2711
2712        # Function used to build the label
2713        def make_node_label(node):
2714            if node_label == "name_class":
2715                if node.is_file:
2716                    return "%s\n(%s)" % (node.basename, node.node_id)
2717                else:
2718                    return (node.pos_str + "\n" + node.__class__.__name__
2719                            if hasattr(node, "pos_str") else str(node))
2720            else:
2721                raise NotImplementedError("node_label: %s" % str(node_label))
2722
2723        labels = {node: make_node_label(node) for node in g.nodes()}
2724        ax, fig, plt = get_ax_fig_plt(ax=ax)
2725
2726        # Select plot type.
2727        if mode == "network":
2728            nx.draw_networkx(g, pos, labels=labels,
2729                             node_color=[node.color_rgb for node in g.nodes()],
2730                             node_size=[make_node_size(node) for node in g.nodes()],
2731                             width=1, style="dotted", with_labels=True, arrows=arrows, ax=ax)
2732
2733            # Draw edge labels
2734            if with_edge_labels:
2735                nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, ax=ax)
2736
2737        elif mode == "status":
2738            # Group tasks by status (only tasks are show here).
2739            for status in self.ALL_STATUS:
2740                tasks = list(self.iflat_tasks(status=status))
2741
2742                # Draw nodes (color is given by status)
2743                node_color = status.color_opts["color"]
2744                if node_color is None: node_color = "black"
2745                #print("num nodes %s with node_color %s" % (len(tasks), node_color))
2746
2747                nx.draw_networkx_nodes(g, pos,
2748                                       nodelist=tasks,
2749                                       node_color=node_color,
2750                                       node_size=[make_node_size(task) for task in tasks],
2751                                       alpha=0.5, ax=ax
2752                                       #label=str(status),
2753                                       )
2754            # Draw edges.
2755            nx.draw_networkx_edges(g, pos, width=2.0, alpha=0.5, arrows=arrows, ax=ax) # edge_color='r')
2756
2757            # Draw labels
2758            nx.draw_networkx_labels(g, pos, labels, font_size=12, ax=ax)
2759
2760            # Draw edge labels
2761            if with_edge_labels:
2762                nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, ax=ax)
2763                #label_pos=0.5, font_size=10, font_color='k', font_family='sans-serif', font_weight='normal',
2764                # alpha=1.0, bbox=None, ax=None, rotate=True, **kwds)
2765
2766        else:
2767            raise ValueError("Unknown value for mode: %s" % str(mode))
2768
2769        ax.axis("off")
2770        return fig
2771
2772    def write_open_notebook(flow, foreground):
2773        """
2774        Generate an ipython notebook and open it in the browser.
2775        Return system exit code.
2776        """
2777        import nbformat
2778        nbf = nbformat.v4
2779        nb = nbf.new_notebook()
2780
2781        nb.cells.extend([
2782            #nbf.new_markdown_cell("This is an auto-generated notebook for %s" % os.path.basename(pseudopath)),
2783            nbf.new_code_cell("""\
2784    import sys, os
2785    import numpy as np
2786
2787    %matplotlib notebook
2788    from IPython.display import display
2789
2790    # This to render pandas DataFrames with https://github.com/quantopian/qgrid
2791    #import qgrid
2792    #qgrid.nbinstall(overwrite=True)  # copies javascript dependencies to your /nbextensions folder
2793
2794    from abipy import abilab
2795
2796    # Tell AbiPy we are inside a notebook and use seaborn settings for plots.
2797    # See https://seaborn.pydata.org/generated/seaborn.set.html#seaborn.set
2798    abilab.enable_notebook(with_seaborn=True)
2799    """),
2800
2801            nbf.new_code_cell("flow = abilab.Flow.pickle_load('%s')" % flow.workdir),
2802            nbf.new_code_cell("if flow.num_errored_tasks: flow.debug()"),
2803            nbf.new_code_cell("flow.check_status(show=True, verbose=0)"),
2804            nbf.new_code_cell("flow.show_dependencies()"),
2805            nbf.new_code_cell("flow.plot_networkx();"),
2806            nbf.new_code_cell("#flow.get_graphviz();"),
2807            nbf.new_code_cell("flow.show_inputs(nids=None, wslice=None)"),
2808            nbf.new_code_cell("flow.show_history()"),
2809            nbf.new_code_cell("flow.show_corrections()"),
2810            nbf.new_code_cell("flow.show_event_handlers()"),
2811            nbf.new_code_cell("flow.inspect(nids=None, wslice=None)"),
2812            nbf.new_code_cell("flow.show_abierrors()"),
2813            nbf.new_code_cell("flow.show_qouts()"),
2814        ])
2815
2816        import tempfile, io
2817        _, nbpath = tempfile.mkstemp(suffix='.ipynb', text=True)
2818
2819        with io.open(nbpath, 'wt', encoding="utf8") as fh:
2820            nbformat.write(nb, fh)
2821
2822        from monty.os.path import which
2823        has_jupyterlab = which("jupyter-lab") is not None
2824        appname = "jupyter-lab" if has_jupyterlab else "jupyter notebook"
2825        if not has_jupyterlab:
2826            if which("jupyter") is None:
2827                raise RuntimeError("Cannot find jupyter in $PATH. Install it with `pip install`")
2828
2829        appname = "jupyter-lab" if has_jupyterlab else "jupyter notebook"
2830
2831        if foreground:
2832            return os.system("%s %s" % (appname, nbpath))
2833        else:
2834            fd, tmpname = tempfile.mkstemp(text=True)
2835            print(tmpname)
2836            cmd = "%s %s" % (appname, nbpath)
2837            print("Executing:", cmd, "\nstdout and stderr redirected to %s" % tmpname)
2838            import subprocess
2839            process = subprocess.Popen(cmd.split(), shell=False, stdout=fd, stderr=fd)
2840            cprint("pid: %s" % str(process.pid), "yellow")
2841            return process.returncode
2842
2843
2844class G0W0WithQptdmFlow(Flow):
2845
2846    def __init__(self, workdir, scf_input, nscf_input, scr_input, sigma_inputs, manager=None):
2847        """
2848        Build a :class:`Flow` for one-shot G0W0 calculations.
2849        The computation of the q-points for the screening is parallelized with qptdm
2850        i.e. we run independent calculations for each q-point and then we merge the final results.
2851
2852        Args:
2853            workdir: Working directory.
2854            scf_input: Input for the GS SCF run.
2855            nscf_input: Input for the NSCF run (band structure run).
2856            scr_input: Input for the SCR run.
2857            sigma_inputs: Input(s) for the SIGMA run(s).
2858            manager: |TaskManager| object used to submit the jobs. Initialized from manager.yml if manager is None.
2859        """
2860        super().__init__(workdir, manager=manager)
2861
2862        # Register the first work (GS + NSCF calculation)
2863        bands_work = self.register_work(BandStructureWork(scf_input, nscf_input))
2864
2865        # Register the callback that will be executed the work for the SCR with qptdm.
2866        scr_work = self.register_work_from_cbk(cbk_name="cbk_qptdm_workflow", cbk_data={"input": scr_input},
2867                                               deps={bands_work.nscf_task: "WFK"}, work_class=QptdmWork)
2868
2869        # The last work contains a list of SIGMA tasks
2870        # that will use the data produced in the previous two works.
2871        if not isinstance(sigma_inputs, (list, tuple)):
2872            sigma_inputs = [sigma_inputs]
2873
2874        sigma_work = Work()
2875        for sigma_input in sigma_inputs:
2876            sigma_work.register_sigma_task(sigma_input, deps={bands_work.nscf_task: "WFK", scr_work: "SCR"})
2877        self.register_work(sigma_work)
2878
2879        self.allocate()
2880
2881    def cbk_qptdm_workflow(self, cbk):
2882        """
2883        This callback is executed by the flow when bands_work.nscf_task reaches S_OK.
2884
2885        It computes the list of q-points for the W(q,G,G'), creates nqpt tasks
2886        in the second work (QptdmWork), and connect the signals.
2887        """
2888        scr_input = cbk.data["input"]
2889        # Use the WFK file produced by the second
2890        # Task in the first Work (NSCF step).
2891        nscf_task = self[0][1]
2892        wfk_file = nscf_task.outdir.has_abiext("WFK")
2893
2894        work = self[1]
2895        work.set_manager(self.manager)
2896        work.create_tasks(wfk_file, scr_input)
2897        work.add_deps(cbk.deps)
2898
2899        work.set_flow(self)
2900        # Each task has a reference to its work.
2901        for task in work:
2902            task.set_work(work)
2903            # Add the garbage collector.
2904            if self.gc is not None: task.set_gc(self.gc)
2905
2906        work.connect_signals()
2907        work.build()
2908
2909        return work
2910
2911
2912class FlowCallbackError(Exception):
2913    """Exceptions raised by FlowCallback."""
2914
2915
2916class FlowCallback(object):
2917    """
2918    This object implements the callbacks executed by the |Flow| when
2919    particular conditions are fulfilled. See on_dep_ok method of |Flow|.
2920
2921    .. note::
2922
2923        I decided to implement callbacks via this object instead of a standard
2924        approach based on bound methods because:
2925
2926            1) pickle (v<=3) does not support the pickling/unplickling of bound methods
2927
2928            2) There's some extra logic and extra data needed for the proper functioning
2929               of a callback at the flow level and this object provides an easy-to-use interface.
2930    """
2931    Error = FlowCallbackError
2932
2933    def __init__(self, func_name, flow, deps, cbk_data):
2934        """
2935        Args:
2936            func_name: String with the name of the callback to execute.
2937                       func_name must be a bound method of flow with signature:
2938
2939                            func_name(self, cbk)
2940
2941                       where self is the Flow instance and cbk is the callback
2942            flow: Reference to the |Flow|.
2943            deps: List of dependencies associated to the callback
2944                  The callback is executed when all dependencies reach S_OK.
2945            cbk_data: Dictionary with additional data that will be passed to the callback via self.
2946        """
2947        self.func_name = func_name
2948        self.flow = flow
2949        self.deps = deps
2950        self.data = cbk_data or {}
2951        self._disabled = False
2952
2953    def __str__(self):
2954        return "%s: %s bound to %s" % (self.__class__.__name__, self.func_name, self.flow)
2955
2956    def __call__(self):
2957        """Execute the callback."""
2958        if self.can_execute():
2959            # Get the bound method of the flow from func_name.
2960            # We use this trick because pickle (format <=3) does not support bound methods.
2961            try:
2962                func = getattr(self.flow, self.func_name)
2963            except AttributeError as exc:
2964                raise self.Error(str(exc))
2965
2966            return func(self)
2967
2968        else:
2969            raise self.Error("You tried to __call_ a callback that cannot be executed!")
2970
2971    def can_execute(self):
2972        """True if we can execute the callback."""
2973        return not self._disabled and all(dep.status == dep.node.S_OK for dep in self.deps)
2974
2975    def disable(self):
2976        """
2977        True if the callback has been disabled. This usually happens when the callback has been executed.
2978        """
2979        self._disabled = True
2980
2981    def enable(self):
2982        """Enable the callback"""
2983        self._disabled = False
2984
2985    def handle_sender(self, sender):
2986        """
2987        True if the callback is associated to the sender
2988        i.e. if the node who sent the signal appears in the
2989        dependencies of the callback.
2990        """
2991        return sender in [d.node for d in self.deps]
2992
2993
2994# Factory functions.
2995def bandstructure_flow(workdir, scf_input, nscf_input, dos_inputs=None, manager=None, flow_class=Flow, allocate=True):
2996    """
2997    Build a |Flow| for band structure calculations.
2998
2999    Args:
3000        workdir: Working directory.
3001        scf_input: Input for the GS SCF run.
3002        nscf_input: Input for the NSCF run (band structure run).
3003        dos_inputs: Input(s) for the NSCF run (dos run).
3004        manager: |TaskManager| object used to submit the jobs. Initialized from manager.yml if manager is None.
3005        flow_class: Flow subclass
3006        allocate: True if the flow should be allocated before returning.
3007
3008    Returns: |Flow| object
3009    """
3010    flow = flow_class(workdir, manager=manager)
3011    work = BandStructureWork(scf_input, nscf_input, dos_inputs=dos_inputs)
3012    flow.register_work(work)
3013
3014    # Handy aliases
3015    flow.scf_task, flow.nscf_task, flow.dos_tasks = work.scf_task, work.nscf_task, work.dos_tasks
3016
3017    if allocate: flow.allocate()
3018    return flow
3019
3020
3021def g0w0_flow(workdir, scf_input, nscf_input, scr_input, sigma_inputs, manager=None, flow_class=Flow, allocate=True):
3022    """
3023    Build a |Flow| for one-shot $G_0W_0$ calculations.
3024
3025    Args:
3026        workdir: Working directory.
3027        scf_input: Input for the GS SCF run.
3028        nscf_input: Input for the NSCF run (band structure run).
3029        scr_input: Input for the SCR run.
3030        sigma_inputs: List of inputs for the SIGMA run.
3031        flow_class: Flow class
3032        manager: |TaskManager| object used to submit the jobs. Initialized from manager.yml if manager is None.
3033        allocate: True if the flow should be allocated before returning.
3034
3035    Returns: |Flow| object
3036    """
3037    flow = flow_class(workdir, manager=manager)
3038    work = G0W0Work(scf_input, nscf_input, scr_input, sigma_inputs)
3039    flow.register_work(work)
3040    if allocate: flow.allocate()
3041    return flow
3042
3043
3044class PhononFlow(Flow):
3045    """
3046    This Flow provides a high-level interface to compute phonons with DFPT
3047    The flow consists of
3048
3049    1) One workflow for the GS run.
3050
3051    2) nqpt works for phonon calculations. Each work contains
3052       nirred tasks where nirred is the number of irreducible phonon perturbations
3053       for that particular q-point.
3054
3055    .. note:
3056
3057        For a much more flexible interface, use the DFPT works defined in works.py
3058        For instance, EPH calculations are much easier to implement by connecting a single
3059        work that computes all the q-points with the EPH tasks instead of using PhononFlow.
3060    """
3061    @classmethod
3062    def from_scf_input(cls, workdir, scf_input, ph_ngqpt, with_becs=True, manager=None, allocate=True):
3063        """
3064        Create a `PhononFlow` for phonon calculations from an `AbinitInput` defining a ground-state run.
3065
3066        Args:
3067            workdir: Working directory of the flow.
3068            scf_input: |AbinitInput| object with the parameters for the GS-SCF run.
3069            ph_ngqpt: q-mesh for phonons. Must be a sub-mesh of the k-mesh used for
3070                electrons. e.g if ngkpt = (8, 8, 8). ph_ngqpt = (4, 4, 4) is a valid choice
3071                whereas ph_ngqpt = (3, 3, 3) is not!
3072            with_becs: True if Born effective charges are wanted.
3073            manager: |TaskManager| object. Read from `manager.yml` if None.
3074            allocate: True if the flow should be allocated before returning.
3075
3076        Return:
3077            :class:`PhononFlow` object.
3078        """
3079        flow = cls(workdir, manager=manager)
3080
3081        # Register the SCF task
3082        flow.register_scf_task(scf_input)
3083        scf_task = flow[0][0]
3084
3085        # Make sure k-mesh and q-mesh are compatible.
3086        scf_ngkpt, ph_ngqpt = np.array(scf_input["ngkpt"]), np.array(ph_ngqpt)
3087
3088        if any(scf_ngkpt % ph_ngqpt != 0):
3089            raise ValueError("ph_ngqpt %s should be a sub-mesh of scf_ngkpt %s" % (ph_ngqpt, scf_ngkpt))
3090
3091        # Get the q-points in the IBZ from Abinit
3092        qpoints = scf_input.abiget_ibz(ngkpt=ph_ngqpt, shiftk=(0, 0, 0), kptopt=1).points
3093
3094        # Create a PhononWork for each q-point. Add DDK and E-field if q == Gamma and with_becs.
3095        for qpt in qpoints:
3096            if np.allclose(qpt, 0) and with_becs:
3097                ph_work = BecWork.from_scf_task(scf_task)
3098            else:
3099                ph_work = PhononWork.from_scf_task(scf_task, qpoints=qpt)
3100
3101            flow.register_work(ph_work)
3102
3103        if allocate: flow.allocate()
3104
3105        return flow
3106
3107    def open_final_ddb(self):
3108        """
3109        Open the DDB file located in the output directory of the flow.
3110
3111        Return:
3112            :class:`DdbFile` object, None if file could not be found or file is not readable.
3113        """
3114        ddb_path = self.outdir.has_abiext("DDB")
3115        if not ddb_path:
3116            if self.status == self.S_OK:
3117                self.history.critical("%s reached S_OK but didn't produce a GSR file in %s" % (self, self.outdir))
3118            return None
3119
3120        from abipy.dfpt.ddb import DdbFile
3121        try:
3122            return DdbFile(ddb_path)
3123        except Exception as exc:
3124            self.history.critical("Exception while reading DDB file at %s:\n%s" % (ddb_path, str(exc)))
3125            return None
3126
3127    def finalize(self):
3128        """This method is called when the flow is completed."""
3129        # Merge all the out_DDB files found in work.outdir.
3130        ddb_files = list(filter(None, [work.outdir.has_abiext("DDB") for work in self]))
3131
3132        # Final DDB file will be produced in the outdir of the work.
3133        out_ddb = self.outdir.path_in("out_DDB")
3134        desc = "DDB file merged by %s on %s" % (self.__class__.__name__, time.asctime())
3135
3136        mrgddb = wrappers.Mrgddb(manager=self.manager, verbose=0)
3137        mrgddb.merge(self.outdir.path, ddb_files, out_ddb=out_ddb, description=desc)
3138        print("Final DDB file available at %s" % out_ddb)
3139
3140        # Call the method of the super class.
3141        retcode = super().finalize()
3142        return retcode
3143
3144
3145class NonLinearCoeffFlow(Flow):
3146    """
3147    1) One workflow for the GS run.
3148
3149    2) nqpt works for electric field calculations. Each work contains
3150       nirred tasks where nirred is the number of irreducible perturbations
3151       for that particular q-point.
3152    """
3153    @classmethod
3154    def from_scf_input(cls, workdir, scf_input, manager=None, allocate=True):
3155        """
3156        Create a `NonlinearFlow` for second order susceptibility calculations from
3157        an `AbinitInput` defining a ground-state run.
3158
3159        Args:
3160            workdir: Working directory of the flow.
3161            scf_input: |AbinitInput| object with the parameters for the GS-SCF run.
3162            manager: |TaskManager| object. Read from `manager.yml` if None.
3163            allocate: True if the flow should be allocated before returning.
3164
3165        Return:
3166            :class:`NonlinearFlow` object.
3167        """
3168        flow = cls(workdir, manager=manager)
3169
3170        flow.register_scf_task(scf_input)
3171        scf_task = flow[0][0]
3172
3173        nl_work = DteWork.from_scf_task(scf_task)
3174
3175        flow.register_work(nl_work)
3176
3177        if allocate: flow.allocate()
3178
3179        return flow
3180
3181    def open_final_ddb(self):
3182        """
3183        Open the DDB file located in the output directory of the flow.
3184
3185        Return:
3186            |DdbFile| object, None if file could not be found or file is not readable.
3187        """
3188        ddb_path = self.outdir.has_abiext("DDB")
3189        if not ddb_path:
3190            if self.status == self.S_OK:
3191                self.history.critical("%s reached S_OK but didn't produce a GSR file in %s" % (self, self.outdir))
3192            return None
3193
3194        from abipy.dfpt.ddb import DdbFile
3195        try:
3196            return DdbFile(ddb_path)
3197        except Exception as exc:
3198            self.history.critical("Exception while reading DDB file at %s:\n%s" % (ddb_path, str(exc)))
3199            return None
3200
3201    def finalize(self):
3202        """This method is called when the flow is completed."""
3203        # Merge all the out_DDB files found in work.outdir.
3204        ddb_files = list(filter(None, [work.outdir.has_abiext("DDB") for work in self]))
3205
3206        # Final DDB file will be produced in the outdir of the work.
3207        out_ddb = self.outdir.path_in("out_DDB")
3208        desc = "DDB file merged by %s on %s" % (self.__class__.__name__, time.asctime())
3209
3210        mrgddb = wrappers.Mrgddb(manager=self.manager, verbose=0)
3211        mrgddb.merge(self.outdir.path, ddb_files, out_ddb=out_ddb, description=desc)
3212
3213        print("Final DDB file available at %s" % out_ddb)
3214
3215        # Call the method of the super class.
3216        retcode = super().finalize()
3217        print("retcode", retcode)
3218        #if retcode != 0: return retcode
3219        return retcode
3220
3221
3222def phonon_conv_flow(workdir, scf_input, qpoints, params, manager=None, allocate=True):
3223    """
3224    Create a |Flow| to perform convergence studies for phonon calculations.
3225
3226    Args:
3227        workdir: Working directory of the flow.
3228        scf_input: |AbinitInput| object defining a GS-SCF calculation.
3229        qpoints: List of list of lists with the reduced coordinates of the q-point(s).
3230        params:
3231            To perform a converge study wrt ecut: params=["ecut", [2, 4, 6]]
3232        manager: |TaskManager| object responsible for the submission of the jobs.
3233            If manager is None, the object is initialized from the yaml file
3234            located either in the working directory or in the user configuration dir.
3235        allocate: True if the flow should be allocated before returning.
3236
3237    Return: |Flow| object.
3238    """
3239    qpoints = np.reshape(qpoints, (-1, 3))
3240
3241    flow = Flow(workdir=workdir, manager=manager)
3242
3243    for qpt in qpoints:
3244        for gs_inp in scf_input.product(*params):
3245            # Register the SCF task
3246            work = flow.register_scf_task(gs_inp)
3247
3248            # Add the PhononWork connected to this scf_task.
3249            flow.register_work(PhononWork.from_scf_task(work[0], qpoints=qpt))
3250
3251    if allocate: flow.allocate()
3252    return flow
3253