"""
Title: framework.py
Purpose: Contains answer tests that are used by yt's various frontends
"""

import contextlib
import glob
import hashlib
import logging
import os
import pickle
import shelve
import sys
import tempfile
import time
import urllib
import urllib.error
import urllib.request
import zlib
from collections import defaultdict

import numpy as np
from matplotlib import image as mpimg
from matplotlib.testing.compare import compare_images
from nose.plugins import Plugin

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.config import ytcfg
from yt.data_objects.static_output import Dataset
from yt.data_objects.time_series import SimulationTimeSeries
from yt.funcs import get_pbar, get_yt_version
from yt.loaders import load, load_simulation
from yt.testing import (
    assert_allclose_units,
    assert_almost_equal,
    assert_equal,
    assert_rel_equal,
)
from yt.utilities.exceptions import YTCloudError, YTNoAnswerNameSpecified, YTNoOldAnswer
from yt.utilities.logger import disable_stream_logging
from yt.visualization import (
    image_writer as image_writer,
    particle_plots as particle_plots,
    plot_window as pw,
    profile_plotter as profile_plotter,
)

mylog = logging.getLogger("nose.plugins.answer-testing")
run_big_data = False

# Set the latest gold and local standard filenames
_latest = ytcfg.get("yt", "gold_standard_filename")
_latest_local = ytcfg.get("yt", "local_standard_filename")
_url_path = ytcfg.get("yt", "answer_tests_url")


class AnswerTesting(Plugin):
    """nose plugin that stores or compares yt answer-test results.

    The command-line flags registered in :meth:`options` decide whether
    results are stored (``--answer-store``) or compared against a reference,
    and whether that reference lives in local shelve files (``--local``) or
    in cloud storage.
    """

    name = "answer-testing"
    _my_version = None

    def options(self, parser, env=os.environ):
        """Register the answer-testing command-line options with nose."""
        super().options(parser, env=env)
        parser.add_option(
            "--answer-name",
            dest="answer_name",
            metavar="str",
            default=None,
            help="The name of the standard to store/compare against",
        )
        parser.add_option(
            "--answer-store",
            dest="store_results",
            metavar="bool",
            default=False,
            action="store_true",
            help="Should we store this result instead of comparing?",
        )
        parser.add_option(
            "--local",
            dest="local_results",
            default=False,
            action="store_true",
            help="Store/load reference results locally?",
        )
        parser.add_option(
            "--answer-big-data",
            dest="big_data",
            default=False,
            help="Should we run against big data, too?",
            action="store_true",
        )
        parser.add_option(
            "--local-dir",
            dest="output_dir",
            metavar="str",
            help="The name of the directory to store local results",
        )

    @property
    def my_version(self):
        """yt version string, computed once and cached on the instance.

        Falls back to ``UNKNOWN<timestamp>`` when ``get_yt_version`` fails.
        """
        if self._my_version is None:
            try:
                version = get_yt_version()
            except Exception:
                version = f"UNKNOWN{time.time()}"
            self._my_version = version
        return self._my_version

    def configure(self, options, conf):
        """Resolve store/compare names and install the storage backend.

        Sets class-level state on :class:`AnswerTestingTest`
        (``result_storage``, ``reference_storage``, ``options``) and the
        module-level ``run_big_data`` flag. Exits the process if ``--local``
        is given without ``--local-dir``.
        """
        super().configure(options, conf)
        if not self.enabled:
            return
        disable_stream_logging()

        # Parse through the storage flags to make sense of them
        # and use reasonable defaults.
        # If we're storing the data, the default storage name is the local
        # latest version.
        if options.store_results:
            if options.answer_name is None:
                self.store_name = _latest_local
            else:
                self.store_name = options.answer_name
            self.compare_name = None
        # If we're not storing, then we're comparing, and we want the default
        # comparison name to be the latest gold standard,
        # either on the network or local.
        else:
            if options.answer_name is None:
                if options.local_results:
                    self.compare_name = _latest_local
                else:
                    self.compare_name = _latest
            else:
                self.compare_name = options.answer_name
            self.store_name = self.my_version

        self.store_results = options.store_results

        ytcfg["yt", "internals", "within_testing"] = True
        AnswerTestingTest.result_storage = self.result_storage = defaultdict(dict)
        if self.compare_name == "SKIP":
            self.compare_name = None
        elif self.compare_name == "latest":
            self.compare_name = _latest

        # Local/Cloud storage
        if options.local_results:
            if options.output_dir is None:
                print("Please supply an output directory with the --local-dir option")
                sys.exit(1)
            storage_class = AnswerTestLocalStorage
            output_dir = os.path.realpath(options.output_dir)
            # Fix up filename for local storage
            if self.compare_name is not None:
                self.compare_name = os.path.join(
                    output_dir, self.compare_name, self.compare_name
                )

            # Create a local directory only when `options.answer_name` is
            # provided. If it is not provided then creating the local directory
            # depends on the `AnswerTestingTest.answer_name` value of the
            # test; that case is handled in the AnswerTestingTest class.
            if options.store_results and options.answer_name is not None:
                name_dir_path = os.path.join(output_dir, self.store_name)
                os.makedirs(name_dir_path, exist_ok=True)
                self.store_name = os.path.join(name_dir_path, self.store_name)
        else:
            storage_class = AnswerTestCloudStorage

        # Initialize answer/reference storage
        AnswerTestingTest.reference_storage = self.storage = storage_class(
            self.compare_name, self.store_name
        )
        AnswerTestingTest.options = options

        self.local_results = options.local_results
        global run_big_data
        run_big_data = options.big_data

    def finalize(self, result=None):
        """Persist collected results when running with ``--answer-store``."""
        if not self.store_results:
            return
        self.storage.dump(self.result_storage)

    def help(self):
        return "yt answer testing support"


class AnswerTestStorage:
    """Base class for reference/answer storage backends.

    ``reference_name`` identifies the stored answers to compare against
    (``None`` means nothing to compare); ``answer_name`` identifies where new
    answers are written (``None`` means do not store). ``cache`` holds
    already-fetched references keyed by dataset name.
    """

    def __init__(self, reference_name=None, answer_name=None):
        self.reference_name = reference_name
        self.answer_name = answer_name
        self.cache = {}

    def dump(self, result_storage):
        """Persist ``result_storage`` (a mapping of dataset name -> results)."""
        raise NotImplementedError

    def get(self, ds_name, default=None):
        """Return the stored results for ``ds_name`` or ``default``."""
        raise NotImplementedError


class AnswerTestCloudStorage(AnswerTestStorage):
    """Answer storage that reads via HTTP and uploads with pyrax cloudfiles."""

    def get(self, ds_name, default=None):
        """Download and unpickle the reference answers for ``ds_name``.

        Results are memoized in ``self.cache``.

        Raises
        ------
        YTNoOldAnswer
            If the server answers with an HTTP error.
        YTCloudError
            If reading the response body fails three times in a row.
        """
        if self.reference_name is None:
            return default
        if ds_name in self.cache:
            return self.cache[ds_name]
        url = _url_path.format(self.reference_name, ds_name)
        try:
            resp = urllib.request.urlopen(url)
        except urllib.error.HTTPError as err:
            raise YTNoOldAnswer(url) from err
        # Make sure the connection is released whatever happens below.
        with resp:
            for _ in range(3):
                try:
                    data = resp.read()
                except Exception:
                    time.sleep(0.01)
                else:
                    # We were successful
                    break
            else:
                # Raise error if all tries were unsuccessful
                raise YTCloudError(url)
        # NOTE(review): unpickling downloaded data can execute arbitrary code;
        # this is only acceptable because the answer store is a controlled
        # environment.
        rv = pickle.loads(data)
        self.cache[ds_name] = rv
        return rv

    def progress_callback(self, current, total):
        # NOTE(review): nothing in this file sets ``self.pbar``.
        self.pbar.update(current)

    def dump(self, result_storage):
        """Upload each dataset's pickled results as ``<answer_name>_<ds_name>``.

        Uses the pyrax cloudfiles container ``yt-answer-tests`` with the
        credentials file ``~/.yt/rackspace``; existing objects of the same
        name are deleted first.
        """
        if self.answer_name is None:
            return
        # This is where we upload our result storage to the cloud container,
        # if we are able to.
        import pyrax

        credentials = os.path.expanduser(os.path.join("~", ".yt", "rackspace"))
        pyrax.set_credential_file(credentials)
        cf = pyrax.cloudfiles
        c = cf.get_container("yt-answer-tests")
        pb = get_pbar("Storing results ", len(result_storage))
        for i, ds_name in enumerate(result_storage):
            pb.update(i + 1)
            rs = pickle.dumps(result_storage[ds_name])
            object_name = f"{self.answer_name}_{ds_name}"
            if object_name in c.get_object_names():
                obj = c.get_object(object_name)
                c.delete_object(obj)
            c.store_object(object_name, rs)
        pb.finish()


class AnswerTestLocalStorage(AnswerTestStorage):
    """Answer storage backed by a local ``shelve`` database."""

    def dump(self, result_storage):
        """Write every dataset's results into the shelve at ``answer_name``.

        Nothing is written when ``answer_name`` is ``None`` or when the run
        was marked ``tainted``.
        """
        # The 'tainted' attribute is automatically set to 'True'
        # if the dataset required for an answer test is missing
        # (see can_run_ds()).
        # This logic check prevents creating a shelve with empty answers.
        storage_is_tainted = result_storage.get("tainted", False)
        if self.answer_name is None or storage_is_tainted:
            return
        # Store data using shelve; the context manager closes it even on error.
        with shelve.open(self.answer_name, protocol=-1) as ds:
            for ds_name in result_storage:
                answer_name = f"{ds_name}"
                if answer_name in ds:
                    mylog.info("Overwriting %s", answer_name)
                ds[answer_name] = result_storage[ds_name]

    def get(self, ds_name, default=None):
        """Return the stored results for ``ds_name`` or ``default`` if absent."""
        if self.reference_name is None:
            return default
        # Read data using shelve
        answer_name = f"{ds_name}"
        with shelve.open(self.reference_name, protocol=-1) as ds:
            try:
                return ds[answer_name]
            except KeyError:
                return default


@contextlib.contextmanager
def temp_cwd(cwd):
    """Temporarily change the working directory to ``cwd``.

    The previous working directory is restored even if the body raises.
    """
    oldcwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(oldcwd)


def can_run_ds(ds_fn, file_check=False):
    """Return whether an answer test on ``ds_fn`` can run.

    ``ds_fn`` may be a loaded ``Dataset`` or a path relative to the
    ``test_data_dir`` config value. With ``file_check`` only file existence is
    checked; otherwise the dataset is loaded. If loading raises
    ``FileNotFoundError`` while ``strict_requires`` is set, the result storage
    is marked ``tainted`` and the error is re-raised.
    """
    result_storage = AnswerTestingTest.result_storage
    if isinstance(ds_fn, Dataset):
        return result_storage is not None
    path = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(path):
        return False
    if file_check:
        return os.path.isfile(os.path.join(path, ds_fn)) and result_storage is not None
    try:
        load(ds_fn)
    except FileNotFoundError:
        if ytcfg.get("yt", "internals", "strict_requires"):
            if result_storage is not None:
                result_storage["tainted"] = True
            raise
        return False
    return result_storage is not None


def can_run_sim(sim_fn, sim_type, file_check=False):
    """Deprecated: simulation counterpart of :func:`can_run_ds`."""
    issue_deprecation_warning(
        "This function is no longer used in the "
        "yt project testing framework and is "
        "targeted for deprecation.",
        since="4.0.0",
        removal="4.1.0",
    )
    result_storage = AnswerTestingTest.result_storage
    if isinstance(sim_fn, SimulationTimeSeries):
        return result_storage is not None
    path = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(path):
        return False
    if file_check:
        return os.path.isfile(os.path.join(path, sim_fn)) and result_storage is not None
    try:
        load_simulation(sim_fn, sim_type)
    except FileNotFoundError:
        if ytcfg.get("yt", "internals", "strict_requires"):
            if result_storage is not None:
                result_storage["tainted"] = True
            raise
        return False
    return result_storage is not None


def data_dir_load(ds_fn, cls=None, args=None, kwargs=None):
    """Load a test dataset and build its index.

    A ``Dataset`` instance is returned unchanged. Returns ``False`` when the
    ``test_data_dir`` config value is not a directory. When ``cls`` is given,
    it is called with the path joined onto ``test_data_dir``; otherwise
    ``yt.load`` is used.
    """
    args = args or ()
    kwargs = kwargs or {}
    path = ytcfg.get("yt", "test_data_dir")
    if isinstance(ds_fn, Dataset):
        return ds_fn
    if not os.path.isdir(path):
        return False
    if cls is None:
        ds = load(ds_fn, *args, **kwargs)
    else:
        ds = cls(os.path.join(path, ds_fn), *args, **kwargs)
    # Accessing the index forces it to be built.
    ds.index
    return ds


def sim_dir_load(sim_fn, path=None, sim_type="Enzo", find_outputs=False):
    """Load a simulation from ``sim_fn`` (or ``path``/``sim_fn``).

    Raises
    ------
    OSError
        If ``path`` is not given and ``sim_fn`` does not exist.
    """
    if path is None and not os.path.exists(sim_fn):
        raise OSError(f"Simulation file {sim_fn!r} does not exist")
    if os.path.exists(sim_fn) or not path:
        path = "."
    return load_simulation(
        os.path.join(path, sim_fn), sim_type, find_outputs=find_outputs
    )


class AnswerTestingTest:
    """Base class for a single answer test.

    Subclasses implement ``run`` (produce a result) and ``compare`` (check a
    new result against a stored one), and define ``_type_name`` and
    ``_attrs``, which feed into :attr:`description`.
    """

    reference_storage = None
    result_storage = None
    prefix = ""
    options = None
    # This variable should be set if we are not providing `--answer-name` as
    # command line parameter while running yt's answer testing using nosetests.
    answer_name = None

    def __init__(self, ds_fn):
        if ds_fn is None:
            self.ds = None
        elif isinstance(ds_fn, Dataset):
            self.ds = ds_fn
        else:
            self.ds = data_dir_load(ds_fn, kwargs={"unit_system": "code"})

    def __call__(self):
        """Run the test, then compare against or store the result.

        Raises
        ------
        YTNoAnswerNameSpecified
            If neither ``--answer-name`` nor ``answer_name`` is set.
        YTNoOldAnswer
            If comparing and no stored answer exists for this test.
        """
        if AnswerTestingTest.result_storage is None:
            return
        nv = self.run()

        # Test answer name should be provided either as command line parameters
        # or by setting AnswerTestingTest.answer_name
        if self.options.answer_name is None and self.answer_name is None:
            raise YTNoAnswerNameSpecified()

        # This is for running answer tests when `--answer-name` is not set in
        # nosetests command line arguments. In this case, derive the answer
        # name from the `answer_name` attribute of the test case. A local is
        # used so repeated calls do not keep prefixing the python version.
        if self.options.answer_name is None:
            pyver = f"py{sys.version_info.major}{sys.version_info.minor}"
            answer_name = f"{pyver}_{self.answer_name}"

            answer_store_dir = os.path.realpath(self.options.output_dir)
            ref_name = os.path.join(answer_store_dir, answer_name, answer_name)
            self.reference_storage.reference_name = ref_name
            self.reference_storage.answer_name = ref_name

            # If we are generating golden answers (passed --answer-store arg):
            # - create the answer directory for this test
            # - self.reference_storage.answer_name will be path to answer files
            if self.options.store_results:
                answer_test_dir = os.path.join(answer_store_dir, answer_name)
                os.makedirs(answer_test_dir, exist_ok=True)
                self.reference_storage.reference_name = None

        if self.reference_storage.reference_name is not None:
            # Compare test generated values against the golden answer
            dd = self.reference_storage.get(self.storage_name)
            if dd is None or self.description not in dd:
                raise YTNoOldAnswer(f"{self.storage_name} : {self.description}")
            ov = dd[self.description]
            self.compare(nv, ov)
        else:
            # Store results (in case of --answer-store arg)
            self.result_storage[self.storage_name][self.description] = nv

    @property
    def storage_name(self):
        """Key under which this test's dataset results are stored."""
        if self.prefix != "":
            return f"{self.prefix}_{self.ds}"
        return str(self.ds)

    def compare(self, new_result, old_result):
        raise RuntimeError

    def create_plot(self, ds, plot_type, plot_field, plot_axis, plot_kwargs=None):
        """Build a plot by class name from ``plot_window`` or ``particle_plots``.

        ``plot_type`` is a class name string; ``plot_kwargs`` is an optional
        dict of extra keyword arguments.
        """
        if plot_type is None:
            raise RuntimeError("Must explicitly request a plot type")
        if plot_kwargs is None:
            plot_kwargs = {}
        cls = getattr(pw, plot_type, None)
        if cls is None:
            cls = getattr(particle_plots, plot_type)
        plot = cls(*(ds, plot_axis, plot_field), **plot_kwargs)
        return plot

    @property
    def sim_center(self):
        """
        This returns the center of the domain.
        """
        return 0.5 * (self.ds.domain_right_edge + self.ds.domain_left_edge)

    @property
    def max_dens_location(self):
        """
        This is a helper function to return the location of the most dense
        point.
        """
        return self.ds.find_max(("gas", "density"))[1]

    @property
    def entire_simulation(self):
        """
        Return an unsorted array of values that cover the entire domain.
        """
        return self.ds.all_data()

    @property
    def description(self):
        """Unique key for this test's result within a dataset's results.

        Built from ``_type_name``, the dataset, the object type, each
        attribute listed in ``_attrs``, and an optional ``suffix``, with dots
        replaced by underscores.
        """
        obj_type = getattr(self, "obj_type", None)
        if obj_type is None:
            oname = "all"
        else:
            oname = "_".join(str(s) for s in obj_type)
        args = [self._type_name, str(self.ds), oname]
        args += [str(getattr(self, an)) for an in self._attrs]
        suffix = getattr(self, "suffix", None)
        if suffix:
            args.append(suffix)
        return "_".join(args).replace(".", "_")


class FieldValuesTest(AnswerTestingTest):
    """Compare the weighted average and the extrema of a field."""

    _type_name = "FieldValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field, obj_type=None, particle_type=False, decimals=10):
        super().__init__(ds_fn)
        self.obj_type = obj_type
        self.field = field
        self.particle_type = particle_type
        self.decimals = decimals

    def run(self):
        obj = create_obj(self.ds, self.obj_type)
        field = obj._determine_fields(self.field)[0]
        fd = self.ds.field_info[field]
        if self.particle_type:
            weight_field = (field[0], "particle_ones")
        elif fd.is_sph_field:
            weight_field = (field[0], "ones")
        else:
            weight_field = ("index", "ones")
        avg = obj.quantities.weighted_average_quantity(field, weight=weight_field)
        mi, ma = obj.quantities.extrema(self.field)
        return [avg, mi, ma]

    def compare(self, new_result, old_result):
        err_msg = f"Field values for {self.field} not equal."
        if hasattr(new_result, "d"):
            new_result = new_result.d
        if hasattr(old_result, "d"):
            old_result = old_result.d
        if self.decimals is None:
            assert_equal(new_result, old_result, err_msg=err_msg, verbose=True)
        else:
            # What we do here is check if the old_result has units; if not, we
            # assume they will be the same as the units of new_result.
            if isinstance(old_result, np.ndarray) and not hasattr(
                old_result, "in_units"
            ):
                # coerce it here to the same units
                old_result = old_result * new_result[0].uq
            assert_allclose_units(
                new_result,
                old_result,
                10.0 ** (-self.decimals),
                err_msg=err_msg,
                verbose=True,
            )


class AllFieldValuesTest(AnswerTestingTest):
    """Compare every value of a field on a data object."""

    _type_name = "AllFieldValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field, obj_type=None, decimals=None):
        super().__init__(ds_fn)
        self.obj_type = obj_type
        self.field = field
        self.decimals = decimals

    def run(self):
        obj = create_obj(self.ds, self.obj_type)
        return obj[self.field]

    def compare(self, new_result, old_result):
        err_msg = f"All field values for {self.field} not equal."
        if hasattr(new_result, "d"):
            new_result = new_result.d
        if hasattr(old_result, "d"):
            old_result = old_result.d
        if self.decimals is None:
            assert_equal(new_result, old_result, err_msg=err_msg, verbose=True)
        else:
            assert_rel_equal(
                new_result, old_result, self.decimals, err_msg=err_msg, verbose=True
            )


class ProjectionValuesTest(AnswerTestingTest):
    """Compare the field data of a projection along one axis."""

    _type_name = "ProjectionValues"
    _attrs = ("field", "axis", "weight_field")

    def __init__(
        self, ds_fn, axis, field, weight_field=None, obj_type=None, decimals=10
    ):
        super().__init__(ds_fn)
        self.axis = axis
        self.field = field
        self.weight_field = weight_field
        self.obj_type = obj_type
        self.decimals = decimals

    def run(self):
        if self.obj_type is not None:
            obj = create_obj(self.ds, self.obj_type)
        else:
            obj = None
        # Projecting along a degenerate (size-1) axis is skipped.
        if self.ds.domain_dimensions[self.axis] == 1:
            return None
        proj = self.ds.proj(
            self.field, self.axis, weight_field=self.weight_field, data_source=obj
        )
        return proj.field_data

    def compare(self, new_result, old_result):
        if new_result is None:
            return
        assert len(new_result) == len(old_result)
        # Build masks of positions that are NaN in any field, separately for
        # the new and the old results, then invert them to select valid cells.
        nind, oind = None, None
        for k in new_result:
            assert k in old_result
            if oind is None:
                oind = np.array(np.isnan(old_result[k]))
            np.logical_or(oind, np.isnan(old_result[k]), oind)
            if nind is None:
                nind = np.array(np.isnan(new_result[k]))
            np.logical_or(nind, np.isnan(new_result[k]), nind)
        oind = ~oind
        nind = ~nind
        for k in new_result:
            err_msg = (
                f"{k} values of {self.field} ({self.weight_field} weighted) "
                f"projection (axis {self.axis}) not equal."
            )
            if k == "weight_field":
                # Our weight_field can vary between unit systems, whereas we
                # can do a unitful comparison for the other fields. So we do
                # not do the test here.
                continue
            nres, ores = new_result[k][nind], old_result[k][oind]
            if hasattr(nres, "d"):
                nres = nres.d
            if hasattr(ores, "d"):
                ores = ores.d
            if self.decimals is None:
                assert_equal(nres, ores, err_msg=err_msg)
            else:
                assert_allclose_units(
                    nres, ores, 10.0 ** -(self.decimals), err_msg=err_msg
                )


class PixelizedProjectionValuesTest(AnswerTestingTest):
    """Compare a 256x256 fixed-resolution buffer of a projection."""

    _type_name = "PixelizedProjectionValues"
    _attrs = ("field", "axis", "weight_field")

    def __init__(self, ds_fn, axis, field, weight_field=None, obj_type=None):
        super().__init__(ds_fn)
        self.axis = axis
        self.field = field
        self.weight_field = weight_field
        self.obj_type = obj_type

    def _get_frb(self, obj):
        """Return ``(projection, frb)``; subclasses may override."""
        proj = self.ds.proj(
            self.field, self.axis, weight_field=self.weight_field, data_source=obj
        )
        frb = proj.to_frb((1.0, "unitary"), 256)
        return proj, frb

    def run(self):
        if self.obj_type is not None:
            obj = create_obj(self.ds, self.obj_type)
        else:
            obj = None
        # Go through _get_frb so subclasses (e.g. the particle variant) can
        # supply their own projection and buffer.
        proj, frb = self._get_frb(obj)
        # Accessing the fields populates frb.data.
        frb[self.field]
        if self.weight_field is not None:
            frb[self.weight_field]
        d = frb.data
        for f in proj.field_data:
            # Sometimes f will be a tuple.
            d[f"{f}_sum"] = proj.field_data[f].sum(dtype="float64")
        return d

    def compare(self, new_result, old_result):
        assert len(new_result) == len(old_result)
        for k in new_result:
            assert k in old_result
        for k in new_result:
            # weight_field does not have units, so we do not directly compare them
            if k == "weight_field_sum":
                continue
            try:
                assert_allclose_units(new_result[k], old_result[k], 1e-10)
            except AssertionError:
                dump_images(new_result[k], old_result[k])
                raise


class PixelizedParticleProjectionValuesTest(PixelizedProjectionValuesTest):
    """Pixelized projection test that uses a ``ParticleProjectionPlot``."""

    def _get_frb(self, obj):
        proj_plot = particle_plots.ParticleProjectionPlot(
            self.ds, self.axis, [self.field], weight_field=self.weight_field
        )
        return proj_plot.data_source, proj_plot.frb


class GridValuesTest(AnswerTestingTest):
    """Compare MD5 hashes of a field's bytes on every grid."""

    _type_name = "GridValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field):
        super().__init__(ds_fn)
        self.field = field

    def run(self):
        hashes = {}
        for g in self.ds.index.grids:
            hashes[g.id] = hashlib.md5(g[self.field].tobytes()).hexdigest()
            g.clear_data()
        return hashes

    def compare(self, new_result, old_result):
        assert len(new_result) == len(old_result)
        for k in new_result:
            assert k in old_result
        for k in new_result:
            if hasattr(new_result[k], "d"):
                new_result[k] = new_result[k].d
            if hasattr(old_result[k], "d"):
                old_result[k] = old_result[k].d
            assert_equal(new_result[k], old_result[k])


class VerifySimulationSameTest(AnswerTestingTest):
    """Compare the ``current_time`` of every output in a simulation."""

    _type_name = "VerifySimulationSame"
    _attrs = ()

    def __init__(self, simulation_obj):
        self.ds = simulation_obj

    def run(self):
        result = [ds.current_time for ds in self.ds]
        return result

    def compare(self, new_result, old_result):
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        # Lengths are equal at this point, so zip does not truncate.
        for new_time, old_time in zip(new_result, old_result):
            assert_equal(
                new_time,
                old_time,
                err_msg="Output times not equal.",
                verbose=True,
            )


class GridHierarchyTest(AnswerTestingTest):
    """Compare grid dimensions, edges, levels and particle counts."""

    _type_name = "GridHierarchy"
    _attrs = ()

    def run(self):
        result = {}
        result["grid_dimensions"] = self.ds.index.grid_dimensions
        result["grid_left_edges"] = self.ds.index.grid_left_edge
        result["grid_right_edges"] = self.ds.index.grid_right_edge
        result["grid_levels"] = self.ds.index.grid_levels
        result["grid_particle_count"] = self.ds.index.grid_particle_count
        return result

    def compare(self, new_result, old_result):
        for k in new_result:
            if hasattr(new_result[k], "d"):
                new_result[k] = new_result[k].d
            if hasattr(old_result[k], "d"):
                old_result[k] = old_result[k].d
            assert_equal(new_result[k], old_result[k])


class ParentageRelationshipsTest(AnswerTestingTest):
    """Compare the parent and children grid ids of every grid."""

    _type_name = "ParentageRelationships"
    _attrs = ()

    def run(self):
        result = {}
        result["parents"] = []
        result["children"] = []
        for g in self.ds.index.grids:
            p = g.Parent
            if p is None:
                result["parents"].append(None)
            elif hasattr(p, "id"):
                result["parents"].append(p.id)
            else:
                result["parents"].append([pg.id for pg in p])
            result["children"].append([c.id for c in g.Children])
        return result

    def compare(self, new_result, old_result):
        # Check lengths first: zip alone would silently ignore extra grids.
        assert len(new_result["parents"]) == len(old_result["parents"])
        assert len(new_result["children"]) == len(old_result["children"])
        for newp, oldp in zip(new_result["parents"], old_result["parents"]):
            assert newp == oldp
        for newc, oldc in zip(new_result["children"], old_result["children"]):
            assert newc == oldc


def dump_images(new_result, old_result, decimals=10):
    """Write both results as projection PNGs and report differing images.

    On mismatch, an ``[[ATTACHMENT|<file>]]`` marker is written to stderr for
    every ``.png`` path mentioned in the ``compare_images`` report. The
    temporary PNGs are left on disk so those attachments stay available.
    """
    tmpfd, old_image = tempfile.mkstemp(suffix=".png")
    os.close(tmpfd)
    tmpfd, new_image = tempfile.mkstemp(suffix=".png")
    os.close(tmpfd)
    image_writer.write_projection(new_result, new_image)
    image_writer.write_projection(old_result, old_image)
    results = compare_images(old_image, new_image, 10 ** (-decimals))
    if results is not None:
        tempfiles = [
            line.strip() for line in results.split("\n") if line.endswith(".png")
        ]
        for fn in tempfiles:
            sys.stderr.write(f"\n[[ATTACHMENT|{fn}]]")
        sys.stderr.write("\n")


def compare_image_lists(new_result, old_result, decimals):
    """Compare two lists of zlib-compressed pickled image arrays.

    Each image pair is written to temporary PNGs and compared with
    matplotlib's ``compare_images``. On failure the temporary files are
    deliberately left on disk (the assertion fires before cleanup) so they
    can be attached when running under Jenkins.
    """
    fns = []
    for _ in range(2):
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        fns.append(tmpname)
    num_images = len(old_result)
    assert num_images > 0
    for i in range(num_images):
        # Images are stored as zlib(ndarray.dumps()), i.e. pickled arrays;
        # np.loads was simply an alias of pickle.loads and is gone in NumPy.
        mpimg.imsave(fns[0], pickle.loads(zlib.decompress(old_result[i])))
        mpimg.imsave(fns[1], pickle.loads(zlib.decompress(new_result[i])))
        results = compare_images(fns[0], fns[1], 10 ** (-decimals))
        if results is not None:
            if os.environ.get("JENKINS_HOME") is not None:
                tempfiles = [
                    line.strip()
                    for line in results.split("\n")
                    if line.endswith(".png")
                ]
                for fn in tempfiles:
                    sys.stderr.write(f"\n[[ATTACHMENT|{fn}]]")
                sys.stderr.write("\n")
        assert_equal(results, None, results)
    for fn in fns:
        os.remove(fn)


class VRImageComparisonTest(AnswerTestingTest):
    """Compare the image produced by saving a volume-rendering scene."""

    _type_name = "VRImageComparison"
    _attrs = ("desc",)

    def __init__(self, scene, ds, desc, decimals):
        super().__init__(None)
        self.obj_type = ("vr",)
        self.ds = ds
        self.scene = scene
        self.desc = desc
        self.decimals = decimals

    def run(self):
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        self.scene.save(tmpname, sigma_clip=1.0)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        return [zlib.compress(image.dumps())]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)


class PlotWindowAttributeTest(AnswerTestingTest):
    """Compare the image of a plot-window plot after calling one attribute.

    ``attr_args`` is an ``(args, kwargs)`` pair passed to the attribute
    ``attr_name``; each callable in ``callback_runners`` is invoked as
    ``runner(test, plot)`` before that.
    """

    _type_name = "PlotWindowAttribute"
    _attrs = (
        "plot_type",
        "plot_field",
        "plot_axis",
        "attr_name",
        "attr_args",
        "callback_id",
    )

    def __init__(
        self,
        ds_fn,
        plot_field,
        plot_axis,
        attr_name,
        attr_args,
        decimals,
        plot_type="SlicePlot",
        callback_id="",
        callback_runners=None,
    ):
        super().__init__(ds_fn)
        self.plot_type = plot_type
        self.plot_field = plot_field
        self.plot_axis = plot_axis
        self.plot_kwargs = {}
        self.attr_name = attr_name
        self.attr_args = attr_args
        self.decimals = decimals
        # callback_id is so that we don't have to hash the actual callbacks
        # run, but instead we call them something
        self.callback_id = callback_id
        if callback_runners is None:
            callback_runners = []
        self.callback_runners = callback_runners

    def run(self):
        plot = self.create_plot(
            self.ds, self.plot_type, self.plot_field, self.plot_axis, self.plot_kwargs
        )
        for r in self.callback_runners:
            r(self, plot)
        attr = getattr(plot, self.attr_name)
        attr(*self.attr_args[0], **self.attr_args[1])
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        plot.save(name=tmpname)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        return [zlib.compress(image.dumps())]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)


class PhasePlotAttributeTest(AnswerTestingTest):
    """Compare the image of a phase plot after calling one attribute."""

    _type_name = "PhasePlotAttribute"
    _attrs = ("plot_type", "x_field", "y_field", "z_field", "attr_name", "attr_args")

    def __init__(
        self,
        ds_fn,
        x_field,
        y_field,
        z_field,
        attr_name,
        attr_args,
        decimals,
        plot_type="PhasePlot",
    ):
        super().__init__(ds_fn)
        self.data_source = self.ds.all_data()
        self.plot_type = plot_type
        self.x_field = x_field
        self.y_field = y_field
        self.z_field = z_field
        self.plot_kwargs = {}
        self.attr_name = attr_name
        self.attr_args = attr_args
        self.decimals = decimals

    def create_plot(
        self, data_source, x_field, y_field, z_field, plot_type, plot_kwargs=None
    ):
        """Build a plot by class name from ``profile_plotter`` or ``particle_plots``.

        ``plot_type`` is a class name string; ``plot_kwargs`` is an optional
        dict of extra keyword arguments.
        """
        if plot_type is None:
            raise RuntimeError("Must explicitly request a plot type")
        if plot_kwargs is None:
            plot_kwargs = {}
        cls = getattr(profile_plotter, plot_type, None)
        if cls is None:
            cls = getattr(particle_plots, plot_type)
        plot = cls(*(data_source, x_field, y_field, z_field), **plot_kwargs)
        return plot

    def run(self):
        plot = self.create_plot(
            self.data_source,
            self.x_field,
            self.y_field,
            self.z_field,
            self.plot_type,
            self.plot_kwargs,
        )
        attr = getattr(plot, self.attr_name)
        attr(*self.attr_args[0], **self.attr_args[1])
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        plot.save(name=tmpname)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        return [zlib.compress(image.dumps())]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)


class GenericArrayTest(AnswerTestingTest):
    """Compare the array (or dict of arrays) returned by ``array_func``."""

    _type_name = "GenericArray"
    _attrs = ("array_func_name", "args", "kwargs")

    def __init__(self, ds_fn, array_func, args=None, kwargs=None, decimals=None):
        super().__init__(ds_fn)
        self.array_func = array_func
        self.array_func_name = array_func.__name__
        self.args = args
        self.kwargs = kwargs
        self.decimals = decimals

    def run(self):
        args = [] if self.args is None else self.args
        kwargs = {} if self.kwargs is None else self.kwargs
        return self.array_func(*args, **kwargs)

    def compare(self, new_result, old_result):
        # Normalize a single result to a one-entry dict.
        if not isinstance(new_result, dict):
            new_result = {"answer": new_result}
            old_result = {"answer": old_result}

        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        for k in new_result:
            if hasattr(new_result[k], "d"):
                new_result[k] = new_result[k].d
            if hasattr(old_result[k], "d"):
                old_result[k] = old_result[k].d
            if self.decimals is None:
                assert_almost_equal(new_result[k], old_result[k])
            else:
                assert_allclose_units(
                    new_result[k], old_result[k], 10 ** (-self.decimals)
                )


class GenericImageTest(AnswerTestingTest):
    """Compare every image that ``image_func`` writes for a filename prefix."""

    _type_name = "GenericImage"
    _attrs = ("image_func_name", "args", "kwargs")

    def __init__(self, ds_fn, image_func, decimals, args=None, kwargs=None):
        super().__init__(ds_fn)
        self.image_func = image_func
        self.image_func_name = image_func.__name__
        self.args = args
        self.kwargs = kwargs
        self.decimals = decimals

    def run(self):
        args = [] if self.args is None else self.args
        kwargs = {} if self.kwargs is None else self.kwargs
        comp_imgs = []
        # The temporary directory (and the images in it) is removed on exit.
        with tempfile.TemporaryDirectory() as tmpdir:
            image_prefix = os.path.join(tmpdir, "test_img")
            self.image_func(image_prefix, *args, **kwargs)
            imgs = sorted(glob.glob(image_prefix + "*"))
            assert len(imgs) > 0
            for img in imgs:
                img_data = mpimg.imread(img)
                comp_imgs.append(zlib.compress(img_data.dumps()))
        return comp_imgs

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)


class AxialPixelizationTest(AnswerTestingTest):
    # This test is typically used once per geometry or coordinates type.
    # Feed it a dataset, and it checks that the results of basic pixelization
    # don't change.
    _type_name = "AxialPixelization"
    _attrs = ("geometry",)

    def __init__(self, ds_fn, decimals=None):
        super().__init__(ds_fn)
        self.decimals = decimals
        self.geometry = self.ds.coordinates.name

    def run(self):
        """Pixelize the image-plane coordinate fields of a central slice per axis."""
        rv = {}
        ds = self.ds
        for i, axis in enumerate(ds.coordinates.axis_order):
            (bounds, center, display_center) = pw.get_window_parameters(
                axis, ds.domain_center, None, ds
            )
            slc = ds.slice(axis, center[i])
            xax = ds.coordinates.axis_name[ds.coordinates.x_axis[axis]]
            yax = ds.coordinates.axis_name[ds.coordinates.y_axis[axis]]
            pix_x = ds.coordinates.pixelize(axis, slc, ("gas", xax), bounds, (512, 512))
            pix_y = ds.coordinates.pixelize(axis, slc, ("gas", yax), bounds, (512, 512))
            # Wipe out invalid values (fillers)
            pix_x[~np.isfinite(pix_x)] = 0.0
            pix_y[~np.isfinite(pix_y)] = 0.0
            rv[f"{axis}_x"] = pix_x
            rv[f"{axis}_y"] = pix_y
        return rv

    def compare(self, new_result, old_result):
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        for k in new_result:
            if hasattr(new_result[k], "d"):
                new_result[k] = new_result[k].d
            if hasattr(old_result[k], "d"):
                old_result[k] = old_result[k].d
            if self.decimals is None:
                assert_almost_equal(new_result[k], old_result[k])
            else:
                assert_allclose_units(
                    new_result[k], old_result[k], 10 ** (-self.decimals)
                )


def requires_sim(sim_fn, sim_type, big_data=False, file_check=False):
    issue_deprecation_warning(
        "This function is no longer used in the "
        "yt project testing framework and is "
        "targeted for deprecation.",
        since="4.0.0",
        removal="4.1.0",
    )

    from functools import wraps

    from nose import SkipTest

    def ffalse(func):
        @wraps(func)
        def fskip(*args, **kwargs):
            raise SkipTest

        return
fskip 1132 1133 def ftrue(func): 1134 return func 1135 1136 if not run_big_data and big_data: 1137 return ffalse 1138 elif not can_run_sim(sim_fn, sim_type, file_check): 1139 return ffalse 1140 else: 1141 return ftrue 1142 1143 1144def requires_answer_testing(): 1145 from functools import wraps 1146 1147 from nose import SkipTest 1148 1149 def ffalse(func): 1150 @wraps(func) 1151 def fskip(*args, **kwargs): 1152 raise SkipTest 1153 1154 return fskip 1155 1156 def ftrue(func): 1157 return func 1158 1159 if AnswerTestingTest.result_storage is not None: 1160 return ftrue 1161 else: 1162 return ffalse 1163 1164 1165def requires_ds(ds_fn, big_data=False, file_check=False): 1166 from functools import wraps 1167 1168 from nose import SkipTest 1169 1170 def ffalse(func): 1171 @wraps(func) 1172 def fskip(*args, **kwargs): 1173 raise SkipTest 1174 1175 return fskip 1176 1177 def ftrue(func): 1178 return func 1179 1180 if not run_big_data and big_data: 1181 return ffalse 1182 elif not can_run_ds(ds_fn, file_check): 1183 return ffalse 1184 else: 1185 return ftrue 1186 1187 1188def small_patch_amr(ds_fn, fields, input_center="max", input_weight=("gas", "density")): 1189 if not can_run_ds(ds_fn): 1190 return 1191 dso = [None, ("sphere", (input_center, (0.1, "unitary")))] 1192 yield GridHierarchyTest(ds_fn) 1193 yield ParentageRelationshipsTest(ds_fn) 1194 for field in fields: 1195 yield GridValuesTest(ds_fn, field) 1196 for axis in [0, 1, 2]: 1197 for dobj_name in dso: 1198 for weight_field in [None, input_weight]: 1199 yield ProjectionValuesTest( 1200 ds_fn, axis, field, weight_field, dobj_name 1201 ) 1202 yield FieldValuesTest(ds_fn, field, dobj_name) 1203 1204 1205def big_patch_amr(ds_fn, fields, input_center="max", input_weight=("gas", "density")): 1206 if not can_run_ds(ds_fn): 1207 return 1208 dso = [None, ("sphere", (input_center, (0.1, "unitary")))] 1209 yield GridHierarchyTest(ds_fn) 1210 yield ParentageRelationshipsTest(ds_fn) 1211 for field in fields: 1212 yield 
GridValuesTest(ds_fn, field) 1213 for axis in [0, 1, 2]: 1214 for dobj_name in dso: 1215 for weight_field in [None, input_weight]: 1216 yield PixelizedProjectionValuesTest( 1217 ds_fn, axis, field, weight_field, dobj_name 1218 ) 1219 1220 1221def _particle_answers( 1222 ds, ds_str_repr, ds_nparticles, fields, proj_test_class, center="c" 1223): 1224 if not can_run_ds(ds): 1225 return 1226 assert_equal(str(ds), ds_str_repr) 1227 dso = [None, ("sphere", (center, (0.1, "unitary")))] 1228 dd = ds.all_data() 1229 # this needs to explicitly be "all" 1230 assert_equal(dd["all", "particle_position"].shape, (ds_nparticles, 3)) 1231 tot = sum( 1232 dd[ptype, "particle_position"].shape[0] for ptype in ds.particle_types_raw 1233 ) 1234 assert_equal(tot, ds_nparticles) 1235 for dobj_name in dso: 1236 for field, weight_field in fields.items(): 1237 particle_type = field[0] in ds.particle_types 1238 for axis in [0, 1, 2]: 1239 if not particle_type: 1240 yield proj_test_class(ds, axis, field, weight_field, dobj_name) 1241 yield FieldValuesTest(ds, field, dobj_name, particle_type=particle_type) 1242 1243 1244def nbody_answer(ds, ds_str_repr, ds_nparticles, fields, center="c"): 1245 return _particle_answers( 1246 ds, 1247 ds_str_repr, 1248 ds_nparticles, 1249 fields, 1250 PixelizedParticleProjectionValuesTest, 1251 center=center, 1252 ) 1253 1254 1255def sph_answer(ds, ds_str_repr, ds_nparticles, fields, center="c"): 1256 return _particle_answers( 1257 ds, 1258 ds_str_repr, 1259 ds_nparticles, 1260 fields, 1261 PixelizedProjectionValuesTest, 1262 center=center, 1263 ) 1264 1265 1266def create_obj(ds, obj_type): 1267 # obj_type should be tuple of 1268 # ( obj_name, ( args ) ) 1269 if obj_type is None: 1270 return ds.all_data() 1271 cls = getattr(ds, obj_type[0]) 1272 obj = cls(*obj_type[1]) 1273 return obj 1274