# coding: utf-8
# flake8: noqa
"""
Common test support for all AbiPy test scripts.

This single module should provide all the common functionality for abipy tests
in a single location, so that test scripts can just import it and work right away.
"""
import os
import numpy
import subprocess
import json
import tempfile
import unittest
try:
    import numpy.testing as nptu
except ImportError:
    import numpy.testing.utils as nptu
import abipy.data as abidata

from functools import wraps
from monty.os.path import which
from monty.string import is_string
from pymatgen.util.testing import PymatgenTest

import logging
logger = logging.getLogger(__file__)

# Directory containing this module. Reference files are looked up in ``root/../test_files``.
root = os.path.dirname(__file__)

__all__ = [
    "AbipyTest",
]


def cmp_version(this, other, op=">="):
    """
    Compare two version strings with the given operator ``op``.

    >>> assert cmp_version("1.1.1", "1.1.0") and not cmp_version("1.1.1", "1.1.0", op="==")
    """
    from pkg_resources import parse_version
    from monty.operator import operator_from_str
    op = operator_from_str(op)
    return op(parse_version(this), parse_version(other))


def has_abinit(version=None, op=">=", manager=None):
    """
    True if abinit is available via TaskManager configuration options.
    If version is not None, `abinit_version op version` is evaluated and the result is returned.

    Args:
        version: Optional version string to compare with the abinit version.
        op: Comparison operator (string) passed to :func:`cmp_version`.
        manager: TaskManager to use. If None, it is built with ``TaskManager.from_user_config()``.
    """
    from abipy.flowtk import TaskManager, AbinitBuild
    manager = TaskManager.from_user_config() if manager is None else manager
    build = AbinitBuild(manager=manager)
    if version is None:
        # A build version of "0.0.0" is treated as "abinit not available".
        return build.version != "0.0.0"
    else:
        return cmp_version(build.version, version, op=op)


# Number of successful calls to has_matplotlib. The "Agg" backend is selected on the first one.
_HAS_MATPLOTLIB_CALLS = 0


def has_matplotlib(version=None, op=">="):
    """
    True if matplotlib_ is installed.
    If version is not None, the result of `matplotlib.__version__ op version` is returned.

    Side effects: selects the non-graphical "Agg" backend and closes all open figures.
    """
    try:
        import matplotlib
    except ImportError:
        print("Skipping matplotlib test")
        return False

    global _HAS_MATPLOTLIB_CALLS
    _HAS_MATPLOTLIB_CALLS += 1

    if _HAS_MATPLOTLIB_CALLS == 1:
        # Use non-graphical display backend during tests.
        matplotlib.use("Agg")

    import matplotlib.pyplot as plt
    # Avoid warnings about too many open figures.
    # http://stackoverflow.com/questions/21884271/warning-about-too-many-open-figures
    plt.close("all")

    backend = matplotlib.get_backend()
    if backend.lower() != "agg":
        # Switch the default backend.
        # This feature is experimental, and is only expected to work switching to an image backend.
        plt.switch_backend("Agg")

    if version is None: return True
    return cmp_version(matplotlib.__version__, version, op=op)


def has_seaborn():
    """True if seaborn_ is installed."""
    try:
        import seaborn as sns
        return True
    except ImportError:
        return False


def has_phonopy(version=None, op=">="):
    """
    True if phonopy_ is installed.
    If version is not None, the result of `phonopy.__version__ op version` is returned.
    """
    try:
        import phonopy
    except ImportError:
        print("Skipping phonopy test")
        return False

    if version is None: return True
    return cmp_version(phonopy.__version__, version, op=op)


def get_mock_module():
    """Return mock module for testing. Raises ImportError if not found."""
    try:
        # py > 3.3
        from unittest import mock
    except ImportError:
        try:
            import mock
        except ImportError:
            print("mock module required for unit tests")
            print("Use py > 3.3 or install it with `pip install mock` if py2.7")
            raise

    return mock


def json_read_abinit_input_from_path(json_path):
    """
    Read a json file from the absolute path ``json_path``, return |AbinitInput| instance.

    Pseudopotential paths stored in the file are replaced by files with the same
    basename in ``abipy/data/pseudos``.
    """
    from abipy.abio.inputs import AbinitInput

    with open(json_path, "rt", encoding="utf-8") as fh:
        d = json.load(fh)

    # Convert pseudo paths: extract basename and build path in abipy/data/pseudos.
    for pdict in d["pseudos"]:
        pdict["filepath"] = os.path.join(abidata.dirpath, "pseudos", os.path.basename(pdict["filepath"]))

    return AbinitInput.from_dict(d)


def input_equality_check(ref_file, input2, rtol=1e-05, atol=1e-08, equal_nan=False):
    """
    Compare an |AbinitInput| with a reference input stored in json format.

    We check that the same variables are present in both inputs and that their values are
    equal (integers, strings) or almost equal (floats, compared with ``numpy.isclose``).
    The structures are compared as well.

    Args:
        ref_file: Path to the reference input in json format, produced e.g. with
            ``json.dump(input.as_dict(), fp, indent=2)``. A relative path is resolved with
            respect to the ``test_files`` directory; an absolute path is used as is.
        input2: |AbinitInput| object to test.
        rtol: relative tolerance passed to numpy.isclose for float comparison.
        atol: absolute tolerance passed to numpy.isclose for float comparison.
        equal_nan: passed to numpy.isclose for float comparison.

    Raises:
        AssertionError: if the two inputs are not equal.
    """
    def check_int(i, j):
        return i != j

    def check_float(x, y):
        return not numpy.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

    def check_str(s, t):
        return s != t

    def check_var(v, w):
        """Return True if the actual value ``v`` differs from the reference value ``w``."""
        # Dispatch on the type of the actual value. numpy scalar types are included because
        # e.g. numpy.int64 is not a subclass of int, and values taken from arrays would
        # otherwise never be compared.
        if isinstance(v, (int, numpy.integer)):
            return check_int(v, w)
        if isinstance(v, (float, numpy.floating)):
            return check_float(v, w)
        if is_string(v):
            return check_str(v, w)
        # Values of other types are not compared.
        return False

    def flatten_var(o, tree_types=(list, tuple, numpy.ndarray)):
        """Recursively flatten nested lists/tuples/arrays into a flat list of scalars."""
        if isinstance(o, numpy.ndarray) and o.ndim == 0:
            # 0-d arrays cannot be iterated: treat them as scalars.
            return [o.item()]
        if not isinstance(o, tree_types):
            return [o]
        flat_var = []
        for value in o:
            flat_var.extend(flatten_var(value, tree_types))
        return flat_var

    input_ref = json_read_abinit_input_from_path(os.path.join(root, '..', 'test_files', ref_file))

    errors = []
    diff_in_ref = [var for var in input_ref.vars if var not in input2.vars]
    diff_in_actual = [var for var in input2.vars if var not in input_ref.vars]
    if diff_in_ref or diff_in_actual:
        error_description = 'not the same input parameters:\n' \
                            ' %s were found in ref but not in actual\n' \
                            ' %s were found in actual but not in ref\n' % \
                            (diff_in_ref, diff_in_actual)
        errors.append(error_description)

    for var, val_r in input_ref.vars.items():
        try:
            val_t = input2.vars[var]
        except KeyError:
            errors.append('variable %s from the reference is not in the actual input\n' % str(var))
            continue

        val_list_t = flatten_var(val_t)
        val_list_r = flatten_var(val_r)

        if len(val_list_t) != len(val_list_r):
            # A different number of entries is a difference in itself (previously a longer actual
            # list passed silently). If equivalent values are reported here, the collection type
            # is probably not handled: add it to the tree_types tuple in flatten_var.
            error = True
        else:
            error = any(check_var(t, r) for t, r in zip(val_list_t, val_list_r))

        if error:
            error_description = 'var %s differs: %s (reference) != %s (actual)' % \
                                (var, val_r, val_t)
            errors.append(error_description)

    if input2.structure != input_ref.structure:
        errors.append('Structures are not the same.\n')
        print(input2.structure, input_ref.structure)

    if errors:
        msg = 'Two inputs were found to be not equal:\n'
        msg += ''.join(' ' + err + '\n' for err in errors)
        raise AssertionError(msg)


def get_gsinput_si(usepaw=0, as_task=False):
    """
    Build and return a GS input file for silicon or a Task if `as_task`.

    Norm-conserving pseudos are used if ``usepaw == 0``, PAW otherwise.
    If ``as_task`` is True, the input is wrapped in a ScfTask.
    """
    pseudos = abidata.pseudos("14si.pspnc") if usepaw == 0 else abidata.pseudos("Si.GGA_PBE-JTH-paw.xml")
    silicon = abidata.cif_file("si.cif")

    from abipy.abio.inputs import AbinitInput
    scf_input = AbinitInput(silicon, pseudos)
    ecut = 6
    scf_input.set_vars(
        ecut=ecut,
        nband=6,
        paral_kgb=0,
        iomode=3,
        toldfe=1e-9,
    )
    if usepaw:
        scf_input.set_vars(pawecutdg=4 * ecut)

    # K-point sampling (shifted)
    scf_input.set_autokmesh(nksmall=4)

    if not as_task:
        return scf_input
    else:
        from abipy.flowtk.tasks import ScfTask
        return ScfTask(scf_input)


def get_gsinput_alas_ngkpt(ngkpt, usepaw=0, as_task=False):
    """
    Build and return a GS input file for AlAs or a Task if `as_task`.

    Only norm-conserving pseudos are supported: NotImplementedError is raised if ``usepaw != 0``.
    """
    if usepaw != 0: raise NotImplementedError("PAW")
    pseudos = abidata.pseudos("13al.981214.fhi", "33as.pspnc")
    structure = abidata.structure_from_ucell("AlAs")

    from abipy.abio.inputs import AbinitInput
    scf_input = AbinitInput(structure, pseudos=pseudos)

    scf_input.set_vars(
        nband=5,
        ecut=8.0,
        ngkpt=ngkpt,
        nshiftk=1,
        shiftk=[0, 0, 0],
        tolvrs=1.0e-6,
        diemac=12.0,
    )

    if not as_task:
        return scf_input
    else:
        from abipy.flowtk.tasks import ScfTask
        return ScfTask(scf_input)


class AbipyTest(PymatgenTest):
    """
    Extends PymatgenTest with Abinit-specific methods.
    Several helper functions are implemented as static methods so that we
    can easily reuse the code in the pytest integration tests.
    """

    SkipTest = unittest.SkipTest

    @staticmethod
    def which(program):
        """Returns full path to a executable. None if not found or not executable."""
        return which(program)

    @staticmethod
    def has_abinit(version=None, op=">="):
        """
        Return True if abinit is available via the TaskManager user configuration.
        If version is not None, return the result of `abinit_version op version`.
        """
        return has_abinit(version=version, op=op)

    def skip_if_abinit_not_ge(self, version):
        """Skip test if Abinit version is not >= `version`"""
        op = ">="
        if not self.has_abinit(version, op=op):
            raise unittest.SkipTest("This test requires Abinit version %s %s" % (op, version))

    @staticmethod
    def has_matplotlib(version=None, op=">="):
        """True if matplotlib_ is installed (and, if version is given, `version op` holds)."""
        return has_matplotlib(version=version, op=op)

    @staticmethod
    def has_seaborn():
        """True if seaborn_ is installed."""
        return has_seaborn()

    @staticmethod
    def has_ase(version=None, op=">="):
        """True if ASE_ package is available."""
        try:
            import ase
        except ImportError:
            return False

        if version is None: return True
        return cmp_version(ase.__version__, version, op=op)

    @staticmethod
    def has_skimage():
        """True if skimage package is available."""
        try:
            from skimage import measure
            return True
        except ImportError:
            return False

    @staticmethod
    def has_python_graphviz(need_dotexec=True):
        """
        True if python-graphviz package is installed and dot executable in path.
        The check on the dot executable is skipped if ``need_dotexec`` is False.
        """
        try:
            from graphviz import Digraph
        except ImportError:
            return False

        return which("dot") is not None if need_dotexec else True

    @staticmethod
    def has_mayavi():
        """
        True if mayavi_ is available. Set also offscreen to True.
        Mayavi tests are enabled only if the TRAVIS environment variable is set.
        """
        # This to run mayavi tests only on Travis
        if not os.environ.get("TRAVIS"): return False
        try:
            from mayavi import mlab
        except ImportError:
            return False

        mlab.options.offscreen = True
        mlab.options.backend = "test"
        return True

    def has_panel(self):
        """False if Panel library is not installed, else the panel module."""
        try:
            import param
            import panel as pn
            import bokeh
            return pn
        except ImportError:
            return False

    def has_networkx(self):
        """False if networkx library is not installed, else the networkx module."""
        try:
            import networkx as nx
            return nx
        except ImportError:
            return False

    def has_graphviz(self):
        """The graphviz module if graphviz library is installed and `dot` in $PATH, else False."""
        try:
            from graphviz import Digraph
            import graphviz
        except ImportError:
            return False

        if self.which("dot") is None: return False
        return graphviz

    @staticmethod
    def get_abistructure_from_abiref(basename):
        """Return an Abipy |Structure| from the basename of one of the reference files."""
        from abipy.core.structure import Structure
        return Structure.as_structure(abidata.ref_file(basename))

    @staticmethod
    def mkdtemp(**kwargs):
        """Invoke mkdtemp with kwargs, return the name of a temporary directory."""
        return tempfile.mkdtemp(**kwargs)

    @staticmethod
    def tmpfileindir(basename, **kwargs):
        """
        Return the absolute path of a temporary file with basename ``basename`` created in a temporary directory.
        Only the directory is created: the file itself is not.
        """
        tmpdir = tempfile.mkdtemp(**kwargs)
        return os.path.join(tmpdir, basename)

    @staticmethod
    def get_tmpname(**kwargs):
        """
        Invoke mkstemp with kwargs, return the name of a temporary file.
        The (empty) file is created on disk.
        """
        fd, tmpname = tempfile.mkstemp(**kwargs)
        # mkstemp returns an open OS-level descriptor: close it, otherwise it leaks.
        os.close(fd)
        return tmpname

    def tmpfile_write(self, string):
        """
        Write string to a temporary file. Returns the name of the temporary file.
        """
        fd, tmpfile = tempfile.mkstemp(text=True)
        # Write through the descriptor returned by mkstemp so that it gets closed.
        with os.fdopen(fd, "w") as fh:
            fh.write(string)

        return tmpfile

    @staticmethod
    def has_nbformat():
        """Return True if nbformat is available and we can test the generation of jupyter_ notebooks."""
        try:
            import nbformat
            return True
        except ImportError:
            return False

    def run_nbpath(self, nbpath):
        """Test that the notebook in question runs all cells correctly. Return (nb, errors)."""
        nb, errors = notebook_run(nbpath)
        return nb, errors

    @staticmethod
    def has_ipywidgets():
        """Return True if ipywidgets_ package is available."""
        # Disabled due to:
        # AttributeError: 'NoneType' object has no attribute 'session'
        return False
        # Disable widget tests on TRAVIS
        #if os.environ.get("TRAVIS"): return False
        try:
            import ipywidgets as ipw
            return True
        except ImportError:
            return False

    @staticmethod
    def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
        """
        Alternative naming for assertArrayAlmostEqual.
        """
        return nptu.assert_almost_equal(actual, desired, decimal, err_msg, verbose)

    @staticmethod
    def assert_equal(actual, desired, err_msg='', verbose=True):
        """
        Alternative naming for assertArrayEqual.
        """
        return nptu.assert_equal(actual, desired, err_msg=err_msg, verbose=verbose)

    @staticmethod
    def json_read_abinit_input(json_basename):
        """Return an |AbinitInput| from the basename of the file in abipy/data/test_files."""
        return json_read_abinit_input_from_path(os.path.join(root, '..', 'test_files', json_basename))

    @staticmethod
    def assert_input_equality(ref_basename, input_to_test, rtol=1e-05, atol=1e-08, equal_nan=False):
        """
        Check equality between an input and a reference in test_files.
        only input variables and structure are compared.

        Args:
            ref_basename: base name of the reference file to test against in test_files
            input_to_test: |AbinitInput| object to test
            rtol: passed to numpy.isclose for float comparison
            atol: passed to numpy.isclose for float comparison
            equal_nan: passed to numpy.isclose for float comparison

        Returns:
            raises an assertion error if the two inputs are not the same
        """
        # input_equality_check already resolves the basename with respect to test_files:
        # joining the directory here as well would produce a wrong path when ``root`` is relative.
        input_equality_check(ref_basename, input_to_test, rtol=rtol, atol=atol, equal_nan=equal_nan)

    @staticmethod
    def straceback():
        """Returns a string with the traceback."""
        import traceback
        return traceback.format_exc()

    @staticmethod
    def skip_if_not_phonopy(version=None, op=">="):
        """
        Raise SkipTest if phonopy_ is not installed.
        Use ``version`` and ``op`` to ask for a specific version
        """
        if not has_phonopy(version=version, op=op):
            if version is None:
                msg = "This test requires phonopy"
            else:
                msg = "This test requires phonopy version %s %s" % (op, version)
            raise unittest.SkipTest(msg)

    @staticmethod
    def skip_if_not_bolztrap2(version=None, op=">="):
        """
        Raise SkipTest if bolztrap2 is not installed.
        Use ``version`` and ``op`` to ask for a specific version
        """
        try:
            import BoltzTraP2 as bzt
        except ImportError:
            raise unittest.SkipTest("This test requires bolztrap2")

        from BoltzTraP2.version import PROGRAM_VERSION
        if version is not None and not cmp_version(PROGRAM_VERSION, version, op=op):
            msg = "This test requires bolztrap2 version %s %s" % (op, version)
            raise unittest.SkipTest(msg)

    def skip_if_not_executable(self, executable):
        """
        Raise SkipTest if executable is not installed.
        """
        if self.which(executable) is None:
            raise unittest.SkipTest("This test requires `%s` in PATH" % str(executable))

    @staticmethod
    def skip_if_not_pseudodojo():
        """
        Raise SkipTest if pseudodojo package is not installed.
        """
        try:
            from pseudo_dojo import OfficialTables
        except ImportError:
            raise unittest.SkipTest("This test requires pseudodojo package.")

    @staticmethod
    def get_mock_module():
        """Return mock module for testing. Raises ImportError if not found."""
        return get_mock_module()

    def decode_with_MSON(self, obj):
        """
        Convert obj into JSON assuming MSONable protocol. Return new object decoded with MontyDecoder.
        """
        from monty.json import MSONable, MontyDecoder
        self.assertIsInstance(obj, MSONable)
        return json.loads(obj.to_json(), cls=MontyDecoder)

    @staticmethod
    def abivalidate_input(abinput, must_fail=False):
        """
        Invoke Abinit to test validity of an |AbinitInput| object
        Print info to stdout if failure before raising AssertionError.

        If ``must_fail`` is True, assert that validation fails with a non-empty log instead.
        """
        v = abinput.abivalidate()
        if must_fail:
            assert v.retcode != 0 and v.log_file.read()
        else:
            if v.retcode != 0:
                print("type abinput:", type(abinput))
                print("abinput:\n", abinput)
                lines = v.log_file.readlines()
                print("Last 50 line from logfile:")
                print("".join(lines[-50:]))

            assert v.retcode == 0

    @staticmethod
    def abivalidate_multi(multi):
        """
        Invoke Abinit to test validity of a |MultiDataset| or a list of |AbinitInput| objects.
        All inputs are validated; the collected tracebacks are printed before asserting.
        """
        if hasattr(multi, "split_datasets"):
            inputs = multi.split_datasets()
        else:
            inputs = multi

        errors = []
        for inp in inputs:
            try:
                AbipyTest.abivalidate_input(inp)
            except Exception as exc:
                errors.append(AbipyTest.straceback())
                errors.append(str(exc))

        if errors:
            for e in errors:
                print(90 * "=")
                print(e)
                print(90 * "=")

        assert not errors

    def abivalidate_work(self, work):
        """Invoke Abinit to test validity of the inputs of a |Work|"""
        from abipy.flowtk import Flow
        tmpdir = tempfile.mkdtemp()
        flow = Flow(workdir=tmpdir)
        flow.register_work(work)
        return self.abivalidate_flow(flow)

    @staticmethod
    def abivalidate_flow(flow):
        """
        Invoke Abinit to test validity of the inputs of a |Flow|.
        Raises RuntimeError (after printing the tail of the failing log files) if validation fails.
        """
        isok, errors = flow.abivalidate_inputs()
        if not isok:
            for e in errors:
                if e.retcode == 0: continue
                lines = e.log_file.readlines()
                print("Last 50 line from logfile:")
                print("".join(lines[-50:]))
            raise RuntimeError("flow.abivalidate_input failed. See messages above.")

    @staticmethod
    @wraps(get_gsinput_si)
    def get_gsinput_si(*args, **kwargs):
        return get_gsinput_si(*args, **kwargs)

    @staticmethod
    @wraps(get_gsinput_alas_ngkpt)
    def get_gsinput_alas_ngkpt(*args, **kwargs):
        return get_gsinput_alas_ngkpt(*args, **kwargs)


def notebook_run(path):
    """
    Execute a notebook via nbconvert and collect output.

    Taken from
    https://blog.thedataincubator.com/2016/06/testing-jupyter-notebooks/

    The nbconvert process runs in the directory containing the notebook.
    The working directory of the calling process is not changed.

    Args:
        path (str): file path for the notebook object

    Returns: (parsed nb object, execution errors)

    """
    import nbformat
    # Use an absolute path: a relative one would not be valid once the subprocess
    # runs inside the notebook directory.
    path = os.path.abspath(path)
    dirname = os.path.dirname(path)
    with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout:
        args = ["jupyter", "nbconvert", "--to", "notebook", "--execute",
                "--ExecutePreprocessor.timeout=300",
                "--ExecutePreprocessor.allow_errors=True",
                "--output", fout.name, path]
        subprocess.check_call(args, cwd=dirname)

        fout.seek(0)
        nb = nbformat.read(fout, nbformat.current_nbformat)

    errors = [output for cell in nb.cells if "outputs" in cell
              for output in cell["outputs"] if output.output_type == "error"]

    return nb, errors