# coding: utf-8
"""Tools and helper functions for abinit calculations"""
import os
import re
import collections.abc
import shutil
import operator
import numpy as np

from fnmatch import fnmatch
from monty.collections import dict2namedtuple
from monty.string import list_strings
from monty.fnmatch import WildCard
from monty.shutil import copy_r
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt

import logging
logger = logging.getLogger(__name__)


def as_bool(s):
    """
    Convert a string into a boolean.

    >>> assert as_bool(True) is True and as_bool("Yes") is True and as_bool("false") is False
    """
    if s in (False, True): return s
    # Assume string
    s = s.lower()
    if s in ("yes", "true"):
        return True
    elif s in ("no", "false"):
        return False
    else:
        raise ValueError("Don't know how to convert type %s: %s into a boolean" % (type(s), s))


class File(object):
    """
    Very simple class used to store file basenames, absolute paths and directory names.
    Provides wrappers for the most commonly used functions defined in os.path.
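
    Example:

        A minimal sketch of typical usage (the filename below is just illustrative):

            stdout = File("run.abo")
            if stdout.exists:
                print(stdout.basename, stdout.getsize())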
42    """
43    def __init__(self, path):
44        self._path = os.path.abspath(path)
45
46    def __repr__(self):
47        return "<%s at %s, %s>" % (self.__class__.__name__, id(self), self.path)
48
49    def __str__(self):
50        return "<%s, %s>" % (self.__class__.__name__, self.path)
51
52    def __eq__(self, other):
53        return False if other is None else self.path == other.path
54
55    def __ne__(self, other):
56        return not self.__eq__(other)
57
58    @property
59    def path(self):
60        """Absolute path of the file."""
61        return self._path
62
63    @property
64    def basename(self):
65        """File basename."""
66        return os.path.basename(self.path)
67
68    @property
69    def relpath(self):
70        """Relative path."""
71        try:
72            return os.path.relpath(self.path)
73        except OSError:
74            # current working directory may not be defined!
75            return self.path
76
77    @property
78    def dirname(self):
79        """Absolute path of the directory where the file is located."""
80        return os.path.dirname(self.path)
81
82    @property
83    def exists(self):
84        """True if file exists."""
85        return os.path.exists(self.path)
86
87    @property
88    def isncfile(self):
89        """True if self is a NetCDF file"""
90        return self.basename.endswith(".nc")
91
92    def chmod(self, mode):
93        """Change the access permissions of a file."""
94        os.chmod(self.path, mode)
95
96    def read(self):
97        """Read data from file."""
98        with open(self.path, "r") as f:
99            return f.read()
100
101    def readlines(self):
102        """Read lines from files."""
        with open(self.path, "r") as f:
            return f.readlines()

    def write(self, string):
        """Write string to file."""
        self.make_dir()
        with open(self.path, "w") as f:
            if not string.endswith("\n"):
                return f.write(string + "\n")
            else:
                return f.write(string)

    def writelines(self, lines):
        """Write a list of strings to file."""
        self.make_dir()
        with open(self.path, "w") as f:
            return f.writelines(lines)

    def make_dir(self):
        """Make the directory where the file is located."""
        if not os.path.exists(self.dirname):
            os.makedirs(self.dirname)

    def remove(self):
        """Remove the file."""
        try:
            os.remove(self.path)
        except Exception:
            pass

    def move(self, dst):
        """
        Recursively move a file or directory to another location. This is
        similar to the Unix "mv" command.
        """
        shutil.move(self.path, dst)

    def get_stat(self):
        """Results from os.stat"""
        return os.stat(self.path)

    def getsize(self):
        """
        Return the size, in bytes, of path.
        Return 0 if the file is empty or it does not exist.
        """
        if not self.exists: return 0
        return os.path.getsize(self.path)


class Directory(object):
    """
    Very simple class that provides helper functions
    wrapping the most commonly used functions defined in os.path.
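
    Example:

        A minimal sketch of typical usage (the path below is just illustrative):

            outdir = Directory("run/outdata")
            outdir.makedirs()
            nc_paths = outdir.list_filepaths(wildcard="*.nc")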
157    """
158    def __init__(self, path):
159        self._path = os.path.abspath(path)
160
161    def __repr__(self):
162        return "<%s at %s, %s>" % (self.__class__.__name__, id(self), self.path)
163
164    def __str__(self):
165        return self.path
166
167    def __eq__(self, other):
168        return False if other is None else self.path == other.path
169
170    def __ne__(self, other):
171        return not self.__eq__(other)
172
173    @property
174    def path(self):
175        """Absolute path of the directory."""
176        return self._path
177
178    @property
179    def relpath(self):
180        """Relative path."""
181        return os.path.relpath(self.path)
182
183    @property
184    def basename(self):
185        """Directory basename."""
186        return os.path.basename(self.path)
187
188    def path_join(self, *p):
189        """
190        Join two or more pathname components, inserting '/' as needed.
191        If any component is an absolute path, all previous path components will be discarded.
192        """
193        return os.path.join(self.path, *p)
194
195    @property
196    def exists(self):
197        """True if file exists."""
        return os.path.exists(self.path)

    def makedirs(self):
        """
        Super-mkdir; create a leaf directory and all intermediate ones.
        Works like mkdir, except that any intermediate path segment (not
        just the rightmost) will be created if it does not exist.
        """
        if not self.exists:
            os.makedirs(self.path)

    def rmtree(self):
        """Recursively delete the directory tree"""
        shutil.rmtree(self.path, ignore_errors=True)

    def copy_r(self, dst):
        """
        Implements a recursive copy function similar to Unix's "cp -r" command.
        """
        return copy_r(self.path, dst)

    def clean(self):
        """Remove all files in the directory tree while preserving the directory"""
        for path in self.list_filepaths():
            try:
                os.remove(path)
            except Exception:
                pass

    def path_in(self, file_basename):
        """Return the absolute path of filename in the directory."""
        return os.path.join(self.path, file_basename)

    def list_filepaths(self, wildcard=None):
        """
        Return the list of absolute filepaths in the directory.

        Args:
            wildcard: String of tokens separated by "|". Each token represents a pattern.
                If wildcard is not None, we return only those files whose basename matches
                the given shell pattern (uses fnmatch).
                Example:
                  wildcard="*.nc|*.pdf" selects only those files that end with .nc or .pdf
        """
        # Select the files in the directory.
        fnames = os.listdir(self.path)
        filepaths = [os.path.join(self.path, f) for f in fnames]
        filepaths = [path for path in filepaths if os.path.isfile(path)]

        if wildcard is not None:
            # Filter using shell patterns.
            w = WildCard(wildcard)
            filepaths = [path for path in filepaths if w.match(os.path.basename(path))]
            #filepaths = WildCard(wildcard).filter(filepaths)

        return filepaths

    def has_abiext(self, ext, single_file=True):
        """
        Returns the absolute path of the ABINIT file with extension ext.
        Supports both Fortran files and netcdf files. In the latter case,
        we check whether a file with extension ext + ".nc" is present
        in the directory. Returns an empty string if the file is not present.

        Raises:
            `ValueError` if multiple files with the given ext are found.
            This implies that this method is not compatible with multiple datasets.
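
        Example:

            A sketch of the expected behavior (`outdir` is an illustrative Directory instance):

                # Absolute path of out_DEN or out_DEN.nc if present in outdir, "" otherwise.
                den_path = outdir.has_abiext("DEN")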
264        """
265        if ext != "abo":
266            ext = ext if ext.startswith('_') else '_' + ext
267
268        files = []
269        for f in self.list_filepaths():
270            # For the time being, we ignore DDB files in nc format.
271            if ext == "_DDB" and f.endswith(".nc"): continue
272            # Ignore BSE text files e.g. GW_NLF_MDF
273            if ext == "_MDF" and not f.endswith(".nc"): continue
            # Ignore DDK.nc files (temporary workaround for v8.8.2 in which
            # the DFPT code produces a new file with DDK.nc extension that enters
            # into conflict with the AbiPy convention).
277            if ext == "_DDK" and f.endswith(".nc"): continue
278
279            if f.endswith(ext) or f.endswith(ext + ".nc"):
280                files.append(f)
281
282        # This should fix the problem with the 1WF files in which the file extension convention is broken
283        if not files:
284            files = [f for f in self.list_filepaths() if fnmatch(f, "*%s*" % ext)]
285
286        if not files:
287            return ""
288
289        if len(files) > 1 and single_file:
290            # ABINIT users must learn that multiple datasets are bad!
291            raise ValueError("Found multiple files with the same extensions:\n %s\n" % files +
292                             "Please avoid multiple datasets!")
293
294        return files[0] if single_file else files
295
296    def symlink_abiext(self, inext, outext):
297        """
        Create a symbolic link (outext --> inext). The file names are implicitly
        given by the ABINIT file extension.

        Example:

            outdir.symlink_abiext('1WF', 'DDK')

        creates the link out_DDK that points to out_1WF.

        Return: 0 on success.

        Raise: RuntimeError
        """
        infile = self.has_abiext(inext)
        if not infile:
            raise RuntimeError('no file with extension `%s` in `%s`' % (inext, self))

        for i in range(len(infile) - 1, -1, -1):
            if infile[i] == '_':
                break
        else:
            raise RuntimeError('Extension `%s` could not be detected in file `%s`' % (inext, infile))

        outfile = infile[:i] + '_' + outext
        if infile.endswith(".nc") and not outfile.endswith(".nc"):
            outfile = outfile + ".nc"

        if os.path.exists(outfile):
            if os.path.islink(outfile):
                if os.path.realpath(outfile) == infile:
                    logger.debug("Link `%s` already exists but it's OK because it points to the correct file" % outfile)
                    return 0
                else:
                    raise RuntimeError("Link at `%s` already exists but it does not point to `%s`" % (outfile, infile))
            else:
                raise RuntimeError('Expecting link at `%s` but found a regular file.' % outfile)

        os.symlink(infile, outfile)

        return 0

    def rename_abiext(self, inext, outext):
        """Rename the Abinit file with extension inext with the new extension outext"""
        infile = self.has_abiext(inext)
        if not infile:
            raise RuntimeError('no file with extension %s in %s' % (inext, self))

        for i in range(len(infile) - 1, -1, -1):
            if infile[i] == '_':
                break
        else:
            raise RuntimeError('Extension %s could not be detected in file %s' % (inext, infile))

        outfile = infile[:i] + '_' + outext
        shutil.move(infile, outfile)
        return 0

    def copy_abiext(self, inext, outext):
        """Copy the Abinit file with extension inext to a new file with the extension outext"""
        infile = self.has_abiext(inext)
        if not infile:
            raise RuntimeError('no file with extension %s in %s' % (inext, self))

        for i in range(len(infile) - 1, -1, -1):
            if infile[i] == '_':
                break
        else:
            raise RuntimeError('Extension %s could not be detected in file %s' % (inext, infile))

        outfile = infile[:i] + '_' + outext
        shutil.copy(infile, outfile)
        return 0

    def remove_exts(self, exts):
        """
        Remove the files with the given extensions. Unlike rmtree, this function preserves the directory path.
        Return list with the absolute paths of the files that have been removed.
        """
        paths = []

        for ext in list_strings(exts):
            path = self.has_abiext(ext)
            if not path: continue
            try:
                os.remove(path)
                paths.append(path)
            except IOError:
                logger.warning("Exception while trying to remove file %s" % path)

        return paths

    def find_last_timden_file(self):
        """
        ABINIT produces an out_TIMx_DEN file for each relaxation step and we need to find the last
        one in order to prepare the restart or to connect other tasks to the structural relaxation.

        This function finds all the TIM?_DEN files in self and returns a namedtuple (path, step)
        where `path` is the path of the last TIM?_DEN file and `step` is the iteration number.
        Returns None if the directory does not contain TIM?_DEN files.
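
        Example:

            A sketch of typical usage (`outdir` is an illustrative Directory instance):

                last = outdir.find_last_timden_file()
                if last is not None:
                    print("Restart from %s (relaxation step %d)" % (last.path, last.step))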
397        """
398        regex = re.compile(r"out_TIM(\d+)_DEN(.nc)?$")

        timden_paths = [f for f in self.list_filepaths() if regex.match(os.path.basename(f))]
        if not timden_paths: return None

        # Build list of (step, path) tuples.
        stepfile_list = []
        for path in timden_paths:
            name = os.path.basename(path)
            match = regex.match(name)
            step, ncext = match.groups()
            stepfile_list.append((int(step), path))

        # DSU sort.
        last = sorted(stepfile_list, key=lambda t: t[0])[-1]
        return dict2namedtuple(step=last[0], path=last[1])

    def find_1wf_files(self):
        """
        Abinit adds the idir-ipert index at the end of the 1WF file and this breaks the extension
        convention, e.g. out_1WF4. This method scans the files in the directory and returns a list of namedtuples.
        Each namedtuple gives the `path` of the 1WF file and the `pertcase` index.
420        """
421        regex = re.compile(r"out_1WF(\d+)(\.nc)?$")
422
423        wf_paths = [f for f in self.list_filepaths() if regex.match(os.path.basename(f))]
424        if not wf_paths: return None
425
426        # Build list of (pertcase, path) tuples.
427        pertfile_list = []
428        for path in wf_paths:
429            name = os.path.basename(path)
430            match = regex.match(name)
431            pertcase, ncext = match.groups()
432            pertfile_list.append((int(pertcase), path))
433
434        # DSU sort.
435        pertfile_list = sorted(pertfile_list, key=lambda t: t[0])
436        return [dict2namedtuple(pertcase=item[0], path=item[1]) for item in pertfile_list]
437
438    def find_1den_files(self):
439        """
        Abinit adds the idir-ipert index at the end of the 1DEN file and this breaks the extension
        convention, e.g. out_DEN1. This method scans the files in the directory and returns a list of namedtuples.
        Each namedtuple gives the `path` of the 1DEN file and the `pertcase` index.
443        """
444        regex = re.compile(r"out_DEN(\d+)(\.nc)?$")
445        den_paths = [f for f in self.list_filepaths() if regex.match(os.path.basename(f))]
446        if not den_paths: return None
447
448        # Build list of (pertcase, path) tuples.
449        pertfile_list = []
450        for path in den_paths:
451            name = os.path.basename(path)
452            match = regex.match(name)
453            pertcase, ncext = match.groups()
454            pertfile_list.append((int(pertcase), path))
455
456        # DSU sort.
457        pertfile_list = sorted(pertfile_list, key=lambda t: t[0])
458        return [dict2namedtuple(pertcase=item[0], path=item[1]) for item in pertfile_list]
459
460
# This dictionary maps ABINIT file extensions to the variables that must be used to read the file in input.
#
# TODO: In Abinit9, it's possible to specify absolute paths with e.g. getden_path,
# so one can avoid creating symbolic links before running, but moving to the new approach
# requires careful testing and not all files support the get*_path syntax!

_EXT2VARS = {
    "DEN": {"irdden": 1},
    "WFK": {"irdwfk": 1},
    "WFQ": {"irdwfq": 1},
    "SCR": {"irdscr": 1},
    "QPS": {"irdqps": 1},
    "1WF": {"ird1wf": 1},
    "1DEN": {"ird1den": 1},
    "BSR": {"irdbsreso": 1},
    "BSC": {"irdbscoup": 1},
    "HAYDR_SAVE": {"irdhaydock": 1},
    "DDK": {"irdddk": 1},
    "DDB": {},
    "DVDB": {},
    "GKK": {},
    "DKK": {},
    "EFMAS.nc": {"irdefmas": 1},
    # Abinit does not implement getkden and irdkden but relies on irden
    "KDEN": {},  #{"irdkden": 1},
    "KERANGE.nc": {"getkerange_filepath": '"indata/in_KERANGE.nc"'},
}


def irdvars_for_ext(ext):
    """
    Returns a dictionary with the ABINIT variables
    that must be used to read the file with extension ext.
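
    >>> assert irdvars_for_ext("DEN") == {"irdden": 1}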
494    """
495    return _EXT2VARS[ext].copy()
496
497
498def abi_extensions():
499    """List with all the ABINIT extensions that are registered."""
500    return list(_EXT2VARS.keys())[:]
501
502
503def abi_splitext(filename):
504    """
505    Split the ABINIT extension from a filename.
506    "Extension" are found by searching in an internal database.
507
508    Returns "(root, ext)" where ext is the registered ABINIT extension
509    The final ".nc" is included (if any)

    >>> assert abi_splitext("foo_WFK") == ('foo_', 'WFK')
    >>> assert abi_splitext("/home/guido/foo_bar_WFK.nc") == ('foo_bar_', 'WFK.nc')
    """
    filename = os.path.basename(filename)
    is_ncfile = False
    if filename.endswith(".nc"):
        is_ncfile = True
        filename = filename[:-3]

    known_extensions = abi_extensions()

    # This algorithm fails if we have two registered extensions where one is
    # the suffix of the other, e.g. HAYDR_SAVE and ANOTHER_HAYDR_SAVE.
    for i in range(len(filename) - 1, -1, -1):
        ext = filename[i:]
        if ext in known_extensions:
            break
    else:
        raise ValueError("Cannot find a registered extension in %s" % filename)

    root = filename[:i]
    if is_ncfile: ext += ".nc"

    return root, ext


class FilepathFixer(object):
    """
    This object modifies the names of particular output files
    produced by ABINIT so that the file extension is preserved.
    Having a one-to-one mapping between file extension and data format
    is indeed fundamental for the correct behaviour of abipy since:

        - We locate the output file by just inspecting the file extension

        - We select the variables that must be added to the input file
          on the basis of the extension specified by the user during
          the initialization of the `AbinitFlow`.

    Unfortunately, ABINIT developers like to append extra stuff
    to the initial extension and therefore we have to call
    `FilepathFixer` to fix the output files produced by the run.

    Example:

        fixer = FilepathFixer()
        fixer.fix_paths('/foo/out_1WF17') == {'/foo/out_1WF17': '/foo/out_1WF'}
        fixer.fix_paths('/foo/out_1WF5.nc') == {'/foo/out_1WF5.nc': '/foo/out_1WF.nc'}
    """
    def __init__(self):
        # dictionary mapping the *official* file extension to
        # the regular expression used to tokenize the basename of the file
        # To add a new file it's sufficient to add a new regexp and
        # a static method _fix_EXTNAME
        self.regs = regs = {}
        regs["1WF"] = re.compile(r"(\w+_)1WF(\d+)(\.nc)?$")
        regs["1DEN"] = re.compile(r"(\w+_)1DEN(\d+)(\.nc)?$")

    @staticmethod
    def _fix_1WF(match):
        root, pert, ncext = match.groups()
        if ncext is None: ncext = ""
        return root + "1WF" + ncext

    @staticmethod
    def _fix_1DEN(match):
        root, pert, ncext = match.groups()
        if ncext is None: ncext = ""
        return root + "1DEN" + ncext

    def _fix_path(self, path):
        for ext, regex in self.regs.items():
            head, tail = os.path.split(path)

            match = regex.match(tail)
            if match:
                newtail = getattr(self, "_fix_" + ext)(match)
                newpath = os.path.join(head, newtail)
                return newpath, ext

        return None, None

    def fix_paths(self, paths):
        """
        Fix the filenames in the iterable paths

        Returns:
            old2new: Mapping old_path --> new_path
        """
        old2new, fixed_exts = {}, []

        for path in list_strings(paths):
            newpath, ext = self._fix_path(path)

            if newpath is not None:
                fixed_exts.append(ext)
                old2new[path] = newpath

        return old2new


def _bop_not(obj):
    """Boolean not."""
    return not bool(obj)


def _bop_and(obj1, obj2):
    """Boolean and."""
    return bool(obj1) and bool(obj2)


def _bop_or(obj1, obj2):
    """Boolean or."""
    return bool(obj1) or bool(obj2)


def _bop_divisible(num1, num2):
    """Return True if num1 is divisible by num2."""
    return (num1 % num2) == 0.0


# Mapping string --> operator.
_UNARY_OPS = {
    "$not": _bop_not,
}

_BIN_OPS = {
    "$eq": operator.eq,
    "$ne": operator.ne,
    "$gt": operator.gt,
    "$ge": operator.ge,
    "$gte": operator.ge,
    "$lt": operator.lt,
    "$le": operator.le,
    "$lte": operator.le,
    "$divisible": _bop_divisible,
    "$and": _bop_and,
    "$or":  _bop_or,
}


_ALL_OPS = list(_UNARY_OPS.keys()) + list(_BIN_OPS.keys())


def map2rpn(map, obj):
    """
    Convert a Mongodb-like dictionary to an RPN list of operands and operators.

    Reverse Polish notation (RPN) is a mathematical notation in which every
    operator follows all of its operands, e.g.

    3 - 4 + 5 -->   3 4 - 5 +

    >>> d = {2.0: {'$eq': 1.0}}
    >>> assert map2rpn(d, None) == [2.0, 1.0, '$eq']
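    >>> d2 = {"$and": [{2: {"$lt": 3}}, {4: {"$gt": 3}}]}
    >>> assert map2rpn(d2, None) == [2, 3, '$lt', 4, 3, '$gt', '$and']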
677    """
678    rpn = []
679
680    for k, v in map.items():
681
682        if k in _ALL_OPS:
683            if isinstance(v, collections.abc.Mapping):
684                # e.g "$not": {"$gt": "one"}
685                # print("in op_vmap",k, v)
686                values = map2rpn(v, obj)
687                rpn.extend(values)
688                rpn.append(k)
689
690            elif isinstance(v, (list, tuple)):
691                # e.g "$and": [{"$not": {"one": 1.0}}, {"two": {"$lt": 3}}]}
692                # print("in_op_list",k, v)
693                for d in v:
694                    rpn.extend(map2rpn(d, obj))
695
696                rpn.append(k)
697
698            else:
699                # Examples
700                # 1) "$eq"": "attribute_name"
701                # 2) "$eq"": 1.0
702                try:
703                    #print("in_otherv",k, v)
704                    rpn.append(getattr(obj, v))
705                    rpn.append(k)
706
707                except TypeError:
708                    #print("in_otherv, raised",k, v)
709                    rpn.extend([v, k])
710        else:
711            try:
712                k = getattr(obj, k)
713            except TypeError:
714                k = k
715
716            if isinstance(v, collections.abc.Mapping):
717                # "one": {"$eq": 1.0}}
718                values = map2rpn(v, obj)
719                rpn.append(k)
720                rpn.extend(values)
721            else:
722                #"one": 1.0
723                rpn.extend([k, v, "$eq"])
724
725    return rpn
726
727
728def evaluate_rpn(rpn):
729    """
    Evaluates the RPN form produced by map2rpn.

    Returns: bool
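
    >>> assert evaluate_rpn([2.0, 1.0, '$gt']) is True
    >>> assert evaluate_rpn(map2rpn({2.0: {'$lt': 1.0}}, None)) is False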
733    """
734    vals_stack = []
735
736    for item in rpn:
737
738        if item in _ALL_OPS:
            # Apply the operator and push the result onto the stack.
            v2 = vals_stack.pop()

            if item in _UNARY_OPS:
                res = _UNARY_OPS[item](v2)

            elif item in _BIN_OPS:
                v1 = vals_stack.pop()
                res = _BIN_OPS[item](v1, v2)
            else:
                raise ValueError("%s not in unary_ops or bin_ops" % str(item))

            vals_stack.append(res)

        else:
            # Push the operand
            vals_stack.append(item)

    assert len(vals_stack) == 1
    assert isinstance(vals_stack[0], bool)

    return vals_stack[0]


class Condition(object):
    """
    This object receives a dictionary that defines a boolean condition whose syntax is similar
    to the one used in mongodb (albeit not all the operators available in mongodb are supported here).

    Example:

    $gt: {field: {$gt: value} }

    $gt selects those documents where the value of the field is greater than (i.e. >) the specified value.

    $and performs a logical AND operation on an array of two or more expressions (e.g. <expression1>, <expression2>, etc.)
    and selects the documents that satisfy all the expressions in the array.

    { $and: [ { <expression1> }, { <expression2> } , ... , { <expressionN> } ] }

    Consider the following examples (mongodb syntax):

    db.inventory.find( { qty: { $gt: 20 } } )

    This query selects all documents in the inventory collection where the qty field value is greater than 20.

    db.inventory.find( { $and: [ { price: 1.99 }, { qty: { $lt: 20 } }, { sale: true } ] } )
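
    The same kind of dictionary can be used to build a Condition and evaluate it on a Python object.
    A minimal sketch (the `nband` attribute name is just illustrative):

        class Record:
            nband = 4

        cond = Condition({"nband": {"$gt": 2}})
        assert cond(Record())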
787    """
788    @classmethod
789    def as_condition(cls, obj):
790        """Convert obj into :class:`Condition`"""
791        if isinstance(obj, cls):
792            return obj
793        else:
794            return cls(cmap=obj)
795
796    def __init__(self, cmap=None):
797        self.cmap = {} if cmap is None else cmap
798
799    def __str__(self):
800        return str(self.cmap)
801
802    def __bool__(self):
803        return bool(self.cmap)
804
805    __nonzero__ = __bool__
806
807    def __call__(self, obj):
808        if not self: return True
809        try:
810            return evaluate_rpn(map2rpn(self.cmap, obj))
811        except Exception as exc:
812            logger.warning("Condition(%s) raised Exception:\n %s" % (type(obj), str(exc)))
813            return False
814
815
816class Editor(object):
817    """
818    Wrapper class that calls the editor specified by the user
819    or the one specified in the $EDITOR env variable.
820    """
821    def __init__(self, editor=None):
822        """If editor is None, $EDITOR is used."""
823        self.editor = os.getenv("EDITOR", "vi") if editor is None else str(editor)
824
825    def edit_files(self, fnames, ask_for_exit=True):
826        exit_status = 0
827        for idx, fname in enumerate(fnames):
828            exit_status = self.edit_file(fname)
829            if ask_for_exit and idx != len(fnames)-1 and self.user_wants_to_exit():
830                break
831        return exit_status
832
833    def edit_file(self, fname):
834        from subprocess import call
835        retcode = call([self.editor, fname])
836
837        if retcode != 0:
838            import warnings
839            warnings.warn("Error while trying to edit file: %s" % fname)
840
841        return retcode
842
843    @staticmethod
844    def user_wants_to_exit():
845        """Show an interactive prompt asking if exit is wanted."""
        # If stdin is not available (EOFError), assume the user wants to exit.
        try:
            answer = input("Do you want to continue [Y/n]")
        except EOFError:
            return True

        return answer.lower().strip() in ["n", "no"]


class SparseHistogram(object):
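    """
    Distributes items into bins and stores only the populated bins.

    Example:

        A minimal usage sketch (the values below are purely illustrative):

            hist = SparseHistogram([0.1, 0.15, 0.7, 0.8], step=0.3)
            for binval, items in zip(hist.binvals, hist.values):
                print(binval, items)   # bin starting point and the items that fall in that bin
    """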

    def __init__(self, items, key=None, num=None, step=None):
        if num is None and step is None:
            raise ValueError("Either num or step must be specified")

        from collections import defaultdict

        values = [key(item) for item in items] if key is not None else items
        start, stop = min(values), max(values)
        if num is None:
            num = int((stop - start) / step)
            if num == 0: num = 1
        mesh = np.linspace(start, stop, num, endpoint=False)

        from monty.bisect import find_le

        hist = defaultdict(list)
        for item, value in zip(items, values):
            # Find rightmost value less than or equal to x.
            # hence each item goes into the bin whose starting value is <= its value.
            pos = find_le(mesh, value)
            hist[mesh[pos]].append(item)

        #new = OrderedDict([(pos, hist[pos]) for pos in sorted(hist.keys(), reverse=reverse)])
        self.binvals = sorted(hist.keys())
        self.values = [hist[pos] for pos in self.binvals]
        self.start, self.stop, self.num = start, stop, num

    @add_fig_kwargs
    def plot(self, ax=None, **kwargs):
        """
        Plot the histogram with matplotlib, returns `matplotlib` figure.
        """
        ax, fig, plt = get_ax_fig_plt(ax)

        yy = [len(v) for v in self.values]
        ax.plot(self.binvals, yy, **kwargs)

        return fig


class Dirviz(object):

    #file_color = np.array((255, 0, 0)) / 255
    #dir_color = np.array((0, 0, 255)) / 255

    def __init__(self, top):
        #if not os.path.isdir(top):
        #    raise TypeError("%s should be a directory!" % str(top))
        self.top = os.path.abspath(top)

    def get_cluster_graph(self, engine="fdp", graph_attr=None, node_attr=None, edge_attr=None):
        """
        Generate directory graph in the DOT language. Directories are shown as clusters.

        .. warning::

            This function scans the entire directory tree starting from top so the resulting
            graph can be really big.

        Args:
            engine: Layout command used. ['dot', 'neato', 'twopi', 'circo', 'fdp', 'sfdp', 'patchwork', 'osage']
            graph_attr: Mapping of (attribute, value) pairs for the graph.
            node_attr: Mapping of (attribute, value) pairs set for all nodes.
            edge_attr: Mapping of (attribute, value) pairs set for all edges.

        Returns: graphviz.Digraph <https://graphviz.readthedocs.io/en/stable/api.html#digraph>
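
        Example:

            A minimal sketch (requires the graphviz Python package and the Graphviz executables):

                graph = Dirviz(".").get_cluster_graph()
                graph.view()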
923        """
924        # https://www.graphviz.org/doc/info/
925        from graphviz import Digraph
926        g = Digraph("directory", #filename="flow_%s.gv" % os.path.basename(self.relworkdir),
927            engine=engine) # if engine == "automatic" else engine)
928
929        # Set graph attributes.
930        #g.attr(label="%s@%s" % (self.__class__.__name__, self.relworkdir))
931        g.attr(label=self.top)
932        #g.attr(fontcolor="white", bgcolor='purple:pink')
933        #g.attr(rankdir="LR", pagedir="BL")
934        #g.attr(constraint="false", pack="true", packMode="clust")
935        g.node_attr.update(color='lightblue2', style='filled')
936        #g.node_attr.update(ranksep='equally')
937
938        # Add input attributes.
939        if graph_attr is not None:
940            g.graph_attr.update(**graph_attr)
941        if node_attr is not None:
942            g.node_attr.update(**node_attr)
943        if edge_attr is not None:
944            g.edge_attr.update(**edge_attr)
945
946        def node_kwargs(path):
947            return dict(
948                #shape="circle",
949                #shape="none",
950                #shape="plaintext",
951                #shape="point",
952                shape="record",
953                #color=node.color_hex,
954                fontsize="8.0",
955                label=os.path.basename(path),
956            )
957
958        edge_kwargs = dict(arrowType="vee", style="solid", minlen="1")
959        cluster_kwargs = dict(rankdir="LR", pagedir="BL", style="rounded", bgcolor="azure2")
960
961        # TODO: Write other method without clusters if not walk.
962        exclude_top_node = False
963        for root, dirs, files in os.walk(self.top):
964            if exclude_top_node and root == self.top: continue
965            cluster_name = "cluster_%s" % root
966            #print("root", root, cluster_name, "dirs", dirs, "files", files, sep="\n")
967
968            with g.subgraph(name=cluster_name) as d:
969                d.attr(**cluster_kwargs)
970                d.attr(rank="source" if (files or dirs) else "sink")
971                d.attr(label=os.path.basename(root))
972                for f in files:
973                    filepath = os.path.join(root, f)
974                    d.node(filepath, **node_kwargs(filepath))
975                    if os.path.islink(filepath):
976                        # Follow the link and use the relpath wrt link as label.
977                        realp = os.path.realpath(filepath)
978                        realp = os.path.relpath(realp, filepath)
979                        #realp = os.path.relpath(realp, self.top)
980                        #print(filepath, realp)
981                        #g.node(realp, **node_kwargs(realp))
982                        g.edge(filepath, realp, **edge_kwargs)
983
984                for dirname in dirs:
985                    dirpath = os.path.join(root, dirname)
986                    #head, basename = os.path.split(dirpath)
987                    new_cluster_name = "cluster_%s" % dirpath
988                    #rank = "source" if os.listdir(dirpath) else "sink"
989                    #g.node(dirpath, rank=rank, **node_kwargs(dirpath))
990                    #g.edge(dirpath, new_cluster_name, **edge_kwargs)
991                    #d.edge(cluster_name, new_cluster_name, minlen="2", **edge_kwargs)
992                    d.edge(cluster_name, new_cluster_name, **edge_kwargs)
993        return g
994