1# coding: utf-8
2# flake8: noqa
3"""
4Common test support for all AbiPy test scripts.
5
6This single module should provide all the common functionality for abipy tests
7in a single location, so that test scripts can just import it and work right away.
8"""
9import os
10import numpy
11import subprocess
12import json
13import tempfile
14import unittest
15try:
16    import numpy.testing as nptu
17except ImportError:
18    import numpy.testing.utils as nptu
19import abipy.data as abidata
20
21from functools import wraps
22from monty.os.path import which
23from monty.string import is_string
24from pymatgen.util.testing import PymatgenTest
25
26import logging
# Module-level logger.
# NOTE(review): the logger is named after the file path (``__file__``) instead of the
# conventional ``__name__``; left unchanged so the logger name seen by users stays the same.
logger = logging.getLogger(__file__)

# Directory containing this module. Reference inputs are looked up in ``root/../test_files``.
root = os.path.dirname(__file__)

__all__ = [
    "AbipyTest",
]
34
35
def cmp_version(this, other, op=">="):
    """
    Evaluate ``this op other`` where both operands are version strings.

    The strings are parsed with ``pkg_resources.parse_version`` and ``op`` is converted
    to a binary operator with ``monty.operator.operator_from_str``.

    >>> assert cmp_version("1.1.1", "1.1.0") and not cmp_version("1.1.1", "1.1.0", op="==")
    """
    from pkg_resources import parse_version
    from monty.operator import operator_from_str

    compare = operator_from_str(op)
    lhs, rhs = parse_version(this), parse_version(other)
    return compare(lhs, rhs)
45
46
def has_abinit(version=None, op=">=", manager=None):
    """
    Return True if abinit can be used with the TaskManager configuration.

    Args:
        version: If given, return the value of ``abinit_version op version`` instead.
        op: Comparison operator (as a string) used when ``version`` is given.
        manager: TaskManager to use. Defaults to ``TaskManager.from_user_config()``.
    """
    from abipy.flowtk import TaskManager, AbinitBuild

    if manager is None:
        manager = TaskManager.from_user_config()

    abinit_version = AbinitBuild(manager=manager).version
    if version is not None:
        return cmp_version(abinit_version, version, op=op)

    # "0.0.0" is what the build object reports when no version could be determined.
    return abinit_version != "0.0.0"
59
60
# Number of times has_matplotlib has been called after a successful import of matplotlib.
# The "Agg" backend is requested with matplotlib.use only on the first call.
_HAS_MATPLOTLIB_CALLS = 0


def has_matplotlib(version=None, op=">="):
    """
    True if matplotlib_ is installed.
    If version is not None, the result of ``matplotlib.__version__ op version`` is returned.

    Side effects (only when matplotlib is importable): the non-graphical "Agg" backend
    is selected and all the open figures are closed.
    """
    try:
        import matplotlib
    except ImportError:
        print("Skipping matplotlib test")
        return False

    global _HAS_MATPLOTLIB_CALLS
    _HAS_MATPLOTLIB_CALLS += 1

    if _HAS_MATPLOTLIB_CALLS == 1:
        # Use a non-graphical backend during the tests.
        matplotlib.use("Agg")

    import matplotlib.pyplot as plt
    # Avoid the "too many open figures" warning:
    # http://stackoverflow.com/questions/21884271/warning-about-too-many-open-figures
    plt.close("all")

    backend = matplotlib.get_backend()
    if backend.lower() != "agg":
        # Someone changed the backend after the first call: switch back to Agg.
        # This feature is experimental and is only expected to work when switching to an image backend.
        plt.switch_backend("Agg")

    if version is None: return True
    return cmp_version(matplotlib.__version__, version, op=op)
97
98
def has_seaborn():
    """Return True if the seaborn_ package can be imported, False otherwise."""
    try:
        import seaborn  # noqa: F401
    except ImportError:
        return False
    return True
106
107
def has_phonopy(version=None, op=">="):
    """
    True if phonopy_ is installed.
    If version is not None, the result of ``phonopy.__version__ op version`` is returned.
    """
    try:
        import phonopy
    except ImportError:
        print("Skipping phonopy test")
        return False

    if version is None: return True
    return cmp_version(phonopy.__version__, version, op=op)
121
122
def get_mock_module():
    """
    Return the mock module used in the unit tests.

    ``unittest.mock`` is preferred (py >= 3.3). The standalone ``mock`` package is
    used as a fallback. ImportError is re-raised if neither is available.
    """
    try:
        from unittest import mock as mock_module
    except ImportError:
        try:
            import mock as mock_module
        except ImportError:
            print("mock module required for unit tests")
            print("Use py > 3.3 or install it with `pip install mock` if py2.7")
            raise

    return mock_module
137
138
def json_read_abinit_input_from_path(json_path):
    """
    Build an |AbinitInput| from the JSON file located at the absolute path ``json_path``.

    The ``filepath`` of every pseudopotential stored in the file is replaced by the
    file with the same basename in ``abipy/data/pseudos``.
    """
    from abipy.abio.inputs import AbinitInput

    with open(json_path, "rt") as fh:
        data = json.load(fh)

    pseudo_dir = os.path.join(abidata.dirpath, "pseudos")
    for pseudo in data["pseudos"]:
        pseudo["filepath"] = os.path.join(pseudo_dir, os.path.basename(pseudo["filepath"]))

    return AbinitInput.from_dict(data)
153
154
def input_equality_check(ref_file, input2, rtol=1e-05, atol=1e-08, equal_nan=False):
    """
    Compare an |AbinitInput| with a reference input stored in JSON format.

    Only the input variables and the structure are compared. The two inputs must
    define exactly the same variables. Integer, boolean and string values must be
    equal. Floating point values must be almost equal according to ``numpy.isclose``.
    Multi-dimensional values are flattened and compared element by element.

    Args:
        ref_file: Path of the reference JSON file, e.g. produced with
            ``json.dump(input.as_dict(), fp, indent=2)``. It is joined to
            ``root/../test_files`` with ``os.path.join``, so a basename is looked
            up in test_files while an absolute path is used as it is.
        input2: |AbinitInput| object to be tested.
        rtol: Relative tolerance passed to ``numpy.isclose`` for floats.
        atol: Absolute tolerance passed to ``numpy.isclose`` for floats.
        equal_nan: Passed to ``numpy.isclose``. If True, NaNs compare equal.

    Raises:
        AssertionError: with the list of all the differences found.
    """
    def check_int(i, j):
        return i != j

    def check_float(x, y):
        return not numpy.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

    def check_str(s, t):
        return s != t

    def check_var(v, w):
        """Return True if the actual scalar ``v`` differs from the reference scalar ``w``."""
        # numpy scalars produced by iterating arrays (numpy.int64, numpy.float32, numpy.bool_, ...)
        # are not instances of the builtin int/float, so they are listed explicitly.
        # Otherwise values stored in numpy arrays would never be compared.
        if isinstance(v, (int, numpy.integer, numpy.bool_)):
            return check_int(v, w)
        if isinstance(v, (float, numpy.floating)):
            return check_float(v, w)
        if is_string(v):
            return check_str(v, w)
        # Values of any other type are not compared.
        return False

    def flatten_var(o, tree_types=(list, tuple, numpy.ndarray)):
        """Return a flat list with the scalar items found in the (possibly nested) object ``o``."""
        if isinstance(o, numpy.ndarray) and o.ndim == 0:
            # Iterating over a 0-d array raises TypeError: treat it as a scalar.
            return [o.item()]
        if isinstance(o, tree_types):
            flat_var = []
            for value in o:
                flat_var.extend(flatten_var(value, tree_types))
            return flat_var
        return [o]

    input_ref = json_read_abinit_input_from_path(os.path.join(root, '..', 'test_files', ref_file))

    errors = []
    diff_in_ref = [var for var in input_ref.vars if var not in input2.vars]
    diff_in_actual = [var for var in input2.vars if var not in input_ref.vars]
    if diff_in_ref or diff_in_actual:
        error_description = 'not the same input parameters:\n' \
                            '     %s were found in ref but not in actual\n' \
                            '     %s were found in actual but not in ref\n' % \
                            (diff_in_ref, diff_in_actual)
        errors.append(error_description)

    for var, val_r in input_ref.vars.items():
        try:
            val_t = input2.vars[var]
        except KeyError:
            errors.append('variable %s from the reference is not in the actual input\n' % str(var))
            continue

        val_list_t = flatten_var(val_t)
        val_list_r = flatten_var(val_r)

        if len(val_list_t) != len(val_list_r):
            # Report a different number of elements as a difference instead of silently
            # ignoring the extra items (or failing with an unrelated IndexError).
            errors.append('var %s differs in the number of elements (%d in reference, %d in actual): '
                          '%s (reference) != %s (actual)' %
                          (var, len(val_list_r), len(val_list_t), val_r, val_t))
            continue

        # any() stops at the first difference, like the original short-circuit `or`.
        if any(check_var(item_t, item_r) for item_t, item_r in zip(val_list_t, val_list_r)):
            error_description = 'var %s differs: %s (reference) != %s (actual)' % \
                                (var, val_r, val_t)
            errors.append(error_description)

    if input2.structure != input_ref.structure:
        errors.append('Structures are not the same.\n')
        print(input2.structure, input_ref.structure)

    if errors:
        msg = 'Two inputs were found to be not equal:\n'
        msg += ''.join('   ' + err + '\n' for err in errors)
        raise AssertionError(msg)
240
241
def get_gsinput_si(usepaw=0, as_task=False):
    """
    Return a ground-state |AbinitInput| for silicon, or a ScfTask wrapping it if ``as_task``.

    Args:
        usepaw: 0 selects a norm-conserving pseudopotential. Any other value selects
            the JTH PAW dataset and also sets ``pawecutdg``.
        as_task: If True, return a ``ScfTask`` instead of the input.
    """
    if usepaw == 0:
        pseudos = abidata.pseudos("14si.pspnc")
    else:
        pseudos = abidata.pseudos("Si.GGA_PBE-JTH-paw.xml")
    silicon = abidata.cif_file("si.cif")

    from abipy.abio.inputs import AbinitInput
    ecut = 6
    scf_input = AbinitInput(silicon, pseudos)
    scf_input.set_vars(ecut=ecut, nband=6, paral_kgb=0, iomode=3, toldfe=1e-9)
    if usepaw:
        # The dense PAW grid cutoff is four times ecut.
        scf_input.set_vars(pawecutdg=4 * ecut)

    # Shifted k-point sampling.
    scf_input.set_autokmesh(nksmall=4)

    if as_task:
        from abipy.flowtk.tasks import ScfTask
        return ScfTask(scf_input)
    return scf_input
270
271
def get_gsinput_alas_ngkpt(ngkpt, usepaw=0, as_task=False):
    """
    Return a ground-state |AbinitInput| for AlAs on an unshifted ``ngkpt`` mesh,
    or a ScfTask wrapping it if ``as_task`` is True.

    Raises:
        NotImplementedError: if ``usepaw`` is not 0 (only norm-conserving pseudos are supported).
    """
    if usepaw != 0:
        raise NotImplementedError("PAW")

    pseudos = abidata.pseudos("13al.981214.fhi", "33as.pspnc")
    structure = abidata.structure_from_ucell("AlAs")

    from abipy.abio.inputs import AbinitInput
    gs_vars = dict(
        nband=5,
        ecut=8.0,
        ngkpt=ngkpt,
        nshiftk=1,
        shiftk=[0, 0, 0],
        tolvrs=1.0e-6,
        diemac=12.0,
    )
    scf_input = AbinitInput(structure, pseudos=pseudos)
    scf_input.set_vars(**gs_vars)

    if as_task:
        from abipy.flowtk.tasks import ScfTask
        return ScfTask(scf_input)
    return scf_input
298
299
class AbipyTest(PymatgenTest):
    """
    Extends PymatgenTest with Abinit-specific methods.
    Several helper functions are implemented as static methods so that we
    can easily reuse the code in the pytest integration tests.
    """

    SkipTest = unittest.SkipTest

    @staticmethod
    def which(program):
        """Returns full path to a executable. None if not found or not executable."""
        return which(program)

    @staticmethod
    def has_abinit(version=None, op=">="):
        """
        Return True if abinit is available.
        If version is not None, return the value of ``abinit_version op version``.
        """
        return has_abinit(version=version, op=op)

    def skip_if_abinit_not_ge(self, version):
        """Skip test if Abinit version is not >= `version`"""
        op = ">="
        if not self.has_abinit(version, op=op):
            raise unittest.SkipTest("This test requires Abinit version %s %s" % (op, version))

    @staticmethod
    def has_matplotlib(version=None, op=">="):
        """True if matplotlib_ is installed (and ``matplotlib.__version__ op version`` if version is given)."""
        return has_matplotlib(version=version, op=op)

    @staticmethod
    def has_seaborn():
        """True if seaborn_ is installed."""
        return has_seaborn()

    @staticmethod
    def has_ase(version=None, op=">="):
        """True if ASE_ package is available."""
        try:
            import ase
        except ImportError:
            return False

        if version is None: return True
        return cmp_version(ase.__version__, version, op=op)

    @staticmethod
    def has_skimage():
        """True if skimage package is available."""
        try:
            from skimage import measure
            return True
        except ImportError:
            return False

    @staticmethod
    def has_python_graphviz(need_dotexec=True):
        """
        True if python-graphviz package is installed and dot executable in path.
        """
        try:
            from graphviz import Digraph
        except ImportError:
            return False

        return which("dot") is not None if need_dotexec else True

    @staticmethod
    def has_mayavi():
        """
        True if mayavi_ is available. Set also offscreen to True
        """
        # mayavi tests are run only on Travis (TRAVIS environment variable set).
        if not os.environ.get("TRAVIS"): return False
        try:
            from mayavi import mlab
        except ImportError:
            return False

        # Configure mayavi for headless testing.
        mlab.options.offscreen = True
        mlab.options.backend = "test"
        return True

    def has_panel(self):
        """Return the panel module if Panel (and its param/bokeh dependencies) is installed, else False."""
        try:
            import param
            import panel as pn
            import bokeh
            return pn
        except ImportError:
            return False

    def has_networkx(self):
        """Return the networkx module if installed, else False."""
        try:
            import networkx as nx
            return nx
        except ImportError:
            return False

    def has_graphviz(self):
        """Return the graphviz module if installed and `dot` is in $PATH, else False."""
        try:
            from graphviz import Digraph
            import graphviz
        except ImportError:
            return False

        if self.which("dot") is None: return False
        return graphviz

    @staticmethod
    def get_abistructure_from_abiref(basename):
        """Return an Abipy |Structure| from the basename of one of the reference files."""
        from abipy.core.structure import Structure
        return Structure.as_structure(abidata.ref_file(basename))

    @staticmethod
    def mkdtemp(**kwargs):
        """Invoke tempfile.mkdtemp with kwargs, return the name of a temporary directory."""
        return tempfile.mkdtemp(**kwargs)

    @staticmethod
    def tmpfileindir(basename, **kwargs):
        """
        Return the absolute path of a temporary file with basename ``basename`` created in a temporary directory.
        """
        tmpdir = tempfile.mkdtemp(**kwargs)
        return os.path.join(tmpdir, basename)

    @staticmethod
    def get_tmpname(**kwargs):
        """
        Invoke tempfile.mkstemp with kwargs, return the name of a temporary file.

        The (empty) file exists on disk. The OS-level descriptor returned by mkstemp
        is closed here so that it does not leak.
        """
        fd, tmpname = tempfile.mkstemp(**kwargs)
        os.close(fd)
        return tmpname

    def tmpfile_write(self, string):
        """
        Write string to a temporary file. Returns the name of the temporary file.
        """
        fd, tmpfile = tempfile.mkstemp(text=True)

        # Write through the descriptor returned by mkstemp so that it gets closed.
        with os.fdopen(fd, "w") as fh:
            fh.write(string)

        return tmpfile

    @staticmethod
    def has_nbformat():
        """Return True if nbformat is available and we can test the generation of jupyter_ notebooks."""
        try:
            import nbformat
            return True
        except ImportError:
            return False

    def run_nbpath(self, nbpath):
        """Test that the notebook in question runs all cells correctly."""
        nb, errors = notebook_run(nbpath)
        return nb, errors

    @staticmethod
    def has_ipywidgets():
        """
        Return True if ipywidgets_ package is available.

        Currently always returns False: widget tests are disabled due to
        ``AttributeError: 'NoneType' object has no attribute 'session'``.
        """
        return False

    @staticmethod
    def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
        """
        Alternative naming for assertArrayAlmostEqual.
        """
        return nptu.assert_almost_equal(actual, desired, decimal, err_msg, verbose)

    @staticmethod
    def assert_equal(actual, desired, err_msg='', verbose=True):
        """
        Alternative naming for assertArrayEqual.
        """
        return nptu.assert_equal(actual, desired, err_msg=err_msg, verbose=verbose)

    @staticmethod
    def json_read_abinit_input(json_basename):
        """Return an |AbinitInput| from the basename of the file in abipy/data/test_files."""
        return json_read_abinit_input_from_path(os.path.join(root, '..', 'test_files', json_basename))

    @staticmethod
    def assert_input_equality(ref_basename, input_to_test, rtol=1e-05, atol=1e-08, equal_nan=False):
        """
        Check equality between an input and a reference in test_files.
        only input variables and structure are compared.

        Args:
            ref_basename: base name of the reference file to test against in test_files
            input_to_test: |AbinitInput| object to test
            rtol: passed to numpy.isclose for float comparison
            atol: passed to numpy.isclose for float comparison
            equal_nan: passed to numpy.isclose for float comparison

        Returns:
            raises an assertion error if the two inputs are not the same
        """
        # input_equality_check already resolves the name with respect to test_files.
        # Passing the basename avoids joining the test_files path twice, which breaks
        # when `root` is a relative path.
        input_equality_check(ref_basename, input_to_test, rtol=rtol, atol=atol, equal_nan=equal_nan)

    @staticmethod
    def straceback():
        """Returns a string with the traceback."""
        import traceback
        return traceback.format_exc()

    @staticmethod
    def skip_if_not_phonopy(version=None, op=">="):
        """
        Raise SkipTest if phonopy_ is not installed.
        Use ``version`` and ``op`` to ask for a specific version
        """
        if not has_phonopy(version=version, op=op):
            if version is None:
                msg = "This test requires phonopy"
            else:
                msg = "This test requires phonopy version %s %s" % (op, version)
            raise unittest.SkipTest(msg)

    @staticmethod
    def skip_if_not_bolztrap2(version=None, op=">="):
        """
        Raise SkipTest if bolztrap2 is not installed.
        Use ``version`` and ``op`` to ask for a specific version
        """
        try:
            import BoltzTraP2 as bzt
        except ImportError:
            raise unittest.SkipTest("This test requires bolztrap2")

        from BoltzTraP2.version import PROGRAM_VERSION
        if version is not None and not cmp_version(PROGRAM_VERSION, version, op=op):
            msg = "This test requires bolztrap2 version %s %s" % (op, version)
            raise unittest.SkipTest(msg)

    def skip_if_not_executable(self, executable):
        """
        Raise SkipTest if executable is not installed.
        """
        if self.which(executable) is None:
            raise unittest.SkipTest("This test requires `%s` in PATH" % str(executable))

    @staticmethod
    def skip_if_not_pseudodojo():
        """
        Raise SkipTest if pseudodojo package is not installed.
        """
        try:
            from pseudo_dojo import OfficialTables
        except ImportError:
            raise unittest.SkipTest("This test requires pseudodojo package.")

    @staticmethod
    def get_mock_module():
        """Return mock module for testing. Raises ImportError if not found."""
        return get_mock_module()

    def decode_with_MSON(self, obj):
        """
        Convert obj into JSON assuming the MSONable protocol. Return new object decoded with MontyDecoder
        """
        from monty.json import MSONable, MontyDecoder
        self.assertIsInstance(obj, MSONable)
        return json.loads(obj.to_json(), cls=MontyDecoder)

    @staticmethod
    def abivalidate_input(abinput, must_fail=False):
        """
        Invoke Abinit to test validity of an |AbinitInput| object
        Print info to stdout if failure before raising AssertionError.
        """
        v = abinput.abivalidate()
        if must_fail:
            assert v.retcode != 0 and v.log_file.read()
        else:
            if v.retcode != 0:
                print("type abinput:", type(abinput))
                print("abinput:\n", abinput)
                lines = v.log_file.readlines()
                i = len(lines) - 50 if len(lines) >= 50 else 0
                print("Last 50 line from logfile:")
                print("".join(lines[i:]))

            assert v.retcode == 0

    @staticmethod
    def abivalidate_multi(multi):
        """
        Invoke Abinit to test validity of a |MultiDataset| or a list of |AbinitInput| objects.
        """
        if hasattr(multi, "split_datasets"):
            inputs = multi.split_datasets()
        else:
            inputs = multi

        errors = []
        for inp in inputs:
            try:
                AbipyTest.abivalidate_input(inp)
            except Exception as exc:
                errors.append(AbipyTest.straceback())
                errors.append(str(exc))

        if errors:
            for e in errors:
                print(90 * "=")
                print(e)
                print(90 * "=")

        assert not errors

    def abivalidate_work(self, work):
        """Invoke Abinit to test validity of the inputs of a |Work|"""
        from abipy.flowtk import Flow
        tmpdir = tempfile.mkdtemp()
        flow = Flow(workdir=tmpdir)
        flow.register_work(work)
        return self.abivalidate_flow(flow)

    @staticmethod
    def abivalidate_flow(flow):
        """
        Invoke Abinit to test validity of the inputs of a |Flow|
        Print the tail of the log file of each failing validation before raising RuntimeError.
        """
        isok, errors = flow.abivalidate_inputs()
        if not isok:
            for e in errors:
                if e.retcode == 0: continue
                lines = e.log_file.readlines()
                i = len(lines) - 50 if len(lines) >= 50 else 0
                print("Last 50 line from logfile:")
                print("".join(lines[i:]))
            raise RuntimeError("flow.abivalidate_input failed. See messages above.")

    @staticmethod
    @wraps(get_gsinput_si)
    def get_gsinput_si(*args, **kwargs):
        return get_gsinput_si(*args, **kwargs)

    @staticmethod
    @wraps(get_gsinput_alas_ngkpt)
    def get_gsinput_alas_ngkpt(*args, **kwargs):
        return get_gsinput_alas_ngkpt(*args, **kwargs)
660
661
def notebook_run(path):
    """
    Execute a notebook via nbconvert and collect output.

    Taken from
    https://blog.thedataincubator.com/2016/06/testing-jupyter-notebooks/

    ``jupyter nbconvert`` is run with the directory containing the notebook as its
    working directory. The working directory of the calling process is NOT changed.

    Args:
        path (str): file path for the notebook object (absolute or relative to the current directory)

    Returns: (parsed nb object, execution errors)
        where the errors are the cell outputs whose ``output_type`` is "error".

    Raises:
        subprocess.CalledProcessError: if the ``jupyter nbconvert`` command fails.
    """
    import nbformat
    # Resolve the path before running nbconvert in another directory. This also
    # gives a non-empty dirname when `path` is a bare filename.
    path = os.path.abspath(path)
    dirname = os.path.dirname(path)

    with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout:
        args = ["jupyter", "nbconvert", "--to", "notebook", "--execute",
                "--ExecutePreprocessor.timeout=300",
                "--ExecutePreprocessor.allow_errors=True",
                "--output", fout.name, path]
        # Use cwd= instead of os.chdir so that the cwd of the test process is left untouched.
        subprocess.check_call(args, cwd=dirname)

        fout.seek(0)
        nb = nbformat.read(fout, nbformat.current_nbformat)

    errors = [output for cell in nb.cells if "outputs" in cell
              for output in cell["outputs"] if output.output_type == "error"]

    return nb, errors
692