1"""
2Title: framework.py
3Purpose: Contains answer tests that are used by yt's various frontends
4"""
5import contextlib
6import glob
7import hashlib
8import logging
9import os
10import pickle
11import shelve
12import sys
13import tempfile
14import time
15import urllib
16import zlib
17from collections import defaultdict
18
19import numpy as np
20from matplotlib import image as mpimg
21from matplotlib.testing.compare import compare_images
22from nose.plugins import Plugin
23
24from yt._maintenance.deprecation import issue_deprecation_warning
25from yt.config import ytcfg
26from yt.data_objects.static_output import Dataset
27from yt.data_objects.time_series import SimulationTimeSeries
28from yt.funcs import get_pbar, get_yt_version
29from yt.loaders import load, load_simulation
30from yt.testing import (
31    assert_allclose_units,
32    assert_almost_equal,
33    assert_equal,
34    assert_rel_equal,
35)
36from yt.utilities.exceptions import YTCloudError, YTNoAnswerNameSpecified, YTNoOldAnswer
37from yt.utilities.logger import disable_stream_logging
38from yt.visualization import (
39    image_writer as image_writer,
40    particle_plots as particle_plots,
41    plot_window as pw,
42    profile_plotter as profile_plotter,
43)
44
45mylog = logging.getLogger("nose.plugins.answer-testing")
46run_big_data = False
47
48# Set the latest gold and local standard filenames
49_latest = ytcfg.get("yt", "gold_standard_filename")
50_latest_local = ytcfg.get("yt", "local_standard_filename")
51_url_path = ytcfg.get("yt", "answer_tests_url")
52
53
class AnswerTesting(Plugin):
    """Nose plugin driving yt's answer testing.

    Parses the answer-testing command line options, decides whether we are
    storing new answers or comparing against previously stored ones, and
    wires the appropriate storage backend into ``AnswerTestingTest``.
    """

    name = "answer-testing"
    # Cached yt version string; filled lazily by the my_version property.
    _my_version = None

    def options(self, parser, env=os.environ):
        """Register the answer-testing command line options with nose."""
        super().options(parser, env=env)
        parser.add_option(
            "--answer-name",
            dest="answer_name",
            metavar="str",
            default=None,
            help="The name of the standard to store/compare against",
        )
        parser.add_option(
            "--answer-store",
            dest="store_results",
            metavar="bool",
            default=False,
            action="store_true",
            help="Should we store this result instead of comparing?",
        )
        parser.add_option(
            "--local",
            dest="local_results",
            default=False,
            action="store_true",
            help="Store/load reference results locally?",
        )
        parser.add_option(
            "--answer-big-data",
            dest="big_data",
            default=False,
            help="Should we run against big data, too?",
            action="store_true",
        )
        parser.add_option(
            "--local-dir",
            dest="output_dir",
            metavar="str",
            help="The name of the directory to store local results",
        )

    @property
    def my_version(self):
        """The yt version string used to label newly stored answers.

        Falls back to a unique ``UNKNOWN<timestamp>`` label when the yt
        version cannot be determined.  The computed value is cached on the
        class so it is stable for the lifetime of the test run.
        """
        # BUG FIX: this getter was declared as ``my_version(self, version=None)``,
        # but a property getter can never receive extra arguments, so the
        # dead ``version`` parameter (always None) has been removed.
        if self._my_version is not None:
            return self._my_version
        try:
            version = get_yt_version()
        except Exception:
            version = f"UNKNOWN{time.time()}"
        self._my_version = version
        return self._my_version

    def configure(self, options, conf):
        """Resolve the store/compare names and install the storage backend."""
        super().configure(options, conf)
        if not self.enabled:
            return
        disable_stream_logging()

        # Parse through the storage flags to make sense of them
        # and use reasonable defaults
        # If we're storing the data, default storage name is local
        # latest version
        if options.store_results:
            if options.answer_name is None:
                self.store_name = _latest_local
            else:
                self.store_name = options.answer_name
            self.compare_name = None
        # if we're not storing, then we're comparing, and we want default
        # comparison name to be the latest gold standard
        # either on network or local
        else:
            if options.answer_name is None:
                if options.local_results:
                    self.compare_name = _latest_local
                else:
                    self.compare_name = _latest
            else:
                self.compare_name = options.answer_name
            self.store_name = self.my_version

        self.store_results = options.store_results

        ytcfg["yt", "internals", "within_testing"] = True
        AnswerTestingTest.result_storage = self.result_storage = defaultdict(dict)
        # "SKIP" disables comparison entirely; "latest" is an alias for the
        # current gold standard name.
        if self.compare_name == "SKIP":
            self.compare_name = None
        elif self.compare_name == "latest":
            self.compare_name = _latest

        # Local/Cloud storage
        if options.local_results:
            if options.output_dir is None:
                print("Please supply an output directory with the --local-dir option")
                sys.exit(1)
            storage_class = AnswerTestLocalStorage
            output_dir = os.path.realpath(options.output_dir)
            # Fix up filename for local storage
            if self.compare_name is not None:
                self.compare_name = os.path.join(
                    output_dir, self.compare_name, self.compare_name
                )

            # Create a local directory only when `options.answer_name` is
            # provided. If it is not provided then creating local directory
            # will depend on the `AnswerTestingTest.answer_name` value of the
            # test, this case is handled in AnswerTestingTest class.
            if options.store_results and options.answer_name is not None:
                name_dir_path = os.path.join(output_dir, self.store_name)
                if not os.path.isdir(name_dir_path):
                    os.makedirs(name_dir_path)
                self.store_name = os.path.join(name_dir_path, self.store_name)
        else:
            storage_class = AnswerTestCloudStorage

        # Initialize answer/reference storage
        AnswerTestingTest.reference_storage = self.storage = storage_class(
            self.compare_name, self.store_name
        )
        AnswerTestingTest.options = options

        self.local_results = options.local_results
        global run_big_data
        run_big_data = options.big_data

    def finalize(self, result=None):
        """If we are in store mode, persist all collected results."""
        if not self.store_results:
            return
        self.storage.dump(self.result_storage)

    def help(self):
        return "yt answer testing support"
188
189
class AnswerTestStorage:
    """Abstract base class for answer storage backends.

    Parameters
    ----------
    reference_name : str, optional
        Name/path of the stored answers to compare against; ``None``
        disables comparison.
    answer_name : str, optional
        Name/path under which newly generated answers are stored; ``None``
        disables storing.
    """

    def __init__(self, reference_name=None, answer_name=None):
        self.reference_name = reference_name
        self.answer_name = answer_name
        # Per-dataset cache of answers already fetched by get().
        self.cache = {}

    def dump(self, result_storage):
        # BUG FIX: this stub previously took an extra ``result`` argument
        # that neither subclass implementations nor any call site use; the
        # signature now matches ``dump(result_storage)`` everywhere.
        raise NotImplementedError

    def get(self, ds_name, default=None):
        raise NotImplementedError
201
202
class AnswerTestCloudStorage(AnswerTestStorage):
    """Answer storage backed by a remote HTTP answer server (read) and
    Rackspace cloudfiles (write)."""

    def get(self, ds_name, default=None):
        """Fetch the pickled answers for *ds_name* from the answer server.

        Returns *default* when no reference name is configured.  Raises
        ``YTNoOldAnswer`` when the answer does not exist on the server and
        ``YTCloudError`` when it exists but cannot be read after retries.
        """
        if self.reference_name is None:
            return default
        if ds_name in self.cache:
            return self.cache[ds_name]
        url = _url_path.format(self.reference_name, ds_name)
        try:
            resp = urllib.request.urlopen(url)
        except urllib.error.HTTPError as err:
            # Chain the original HTTP error so the failure is debuggable.
            raise YTNoOldAnswer(url) from err
        else:
            for _ in range(3):
                try:
                    data = resp.read()
                except Exception:
                    time.sleep(0.01)
                else:
                    # We were successful
                    break
            else:
                # Raise error if all tries were unsuccessful
                raise YTCloudError(url)
            # This is dangerous, but we have a controlled S3 environment
            rv = pickle.loads(data)
        self.cache[ds_name] = rv
        return rv

    def progress_callback(self, current, total):
        # Upload progress hook; requires self.pbar to have been set by the
        # caller before the transfer starts.
        self.pbar.update(current)

    def dump(self, result_storage):
        """Upload pickled results to the 'yt-answer-tests' cloudfiles
        container, replacing any existing object of the same name."""
        if self.answer_name is None:
            return
        # This is where we dump our result storage up to Amazon, if we are able
        # to.
        import pyrax

        credentials = os.path.expanduser(os.path.join("~", ".yt", "rackspace"))
        pyrax.set_credential_file(credentials)
        cf = pyrax.cloudfiles
        c = cf.get_container("yt-answer-tests")
        pb = get_pbar("Storing results ", len(result_storage))
        for i, ds_name in enumerate(result_storage):
            pb.update(i + 1)
            rs = pickle.dumps(result_storage[ds_name])
            object_name = f"{self.answer_name}_{ds_name}"
            if object_name in c.get_object_names():
                obj = c.get_object(object_name)
                c.delete_object(obj)
            c.store_object(object_name, rs)
        pb.finish()
255
256
class AnswerTestLocalStorage(AnswerTestStorage):
    """Answer storage backed by a local :mod:`shelve` database."""

    def dump(self, result_storage):
        """Write every answer in *result_storage* into the shelve file."""
        # The 'tainted' attribute is automatically set to 'True'
        # if the dataset required for an answer test is missing
        # (see can_run_ds()).
        # This logic check prevents creating a shelve with empty answers.
        storage_is_tainted = result_storage.get("tainted", False)
        if self.answer_name is None or storage_is_tainted:
            return
        # Store data using shelve; the context manager guarantees the
        # database is closed even if a write raises.
        with shelve.open(self.answer_name, protocol=-1) as ds:
            for ds_name in result_storage:
                answer_name = f"{ds_name}"
                if answer_name in ds:
                    mylog.info("Overwriting %s", answer_name)
                ds[answer_name] = result_storage[ds_name]

    def get(self, ds_name, default=None):
        """Read the stored answer for *ds_name*, or *default* if absent."""
        if self.reference_name is None:
            return default
        # Read data using shelve
        answer_name = f"{ds_name}"
        with shelve.open(self.reference_name, protocol=-1) as ds:
            try:
                result = ds[answer_name]
            except KeyError:
                result = default
        return result
287
288
@contextlib.contextmanager
def temp_cwd(cwd):
    """Temporarily change the working directory to *cwd*.

    The original working directory is restored when the block exits,
    even if the body raises.
    """
    oldcwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        # BUG FIX: previously an exception in the body skipped this chdir,
        # leaving the process stranded in the temporary directory.
        os.chdir(oldcwd)
295
296
def can_run_ds(ds_fn, file_check=False):
    """Return True when *ds_fn* is usable and answer storage is active.

    Accepts either an already-loaded ``Dataset`` or a filename relative to
    the configured test data directory.  With ``file_check=True`` only the
    file's existence is verified; otherwise an actual load is attempted.
    """
    storage = AnswerTestingTest.result_storage
    if isinstance(ds_fn, Dataset):
        return storage is not None
    data_dir = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(data_dir):
        return False
    if file_check:
        return storage is not None and os.path.isfile(os.path.join(data_dir, ds_fn))
    try:
        load(ds_fn)
    except FileNotFoundError:
        if ytcfg.get("yt", "internals", "strict_requires"):
            # Mark the shared result storage so no empty answer set is
            # written out for this run.
            if storage is not None:
                storage["tainted"] = True
            raise
        return False
    return storage is not None
315
316
def can_run_sim(sim_fn, sim_type, file_check=False):
    """Return True when simulation *sim_fn* is usable and storage is active.

    Deprecated; kept for backward compatibility with older test suites.
    """
    issue_deprecation_warning(
        "This function is no longer used in the "
        "yt project testing framework and is "
        "targeted for deprecation.",
        since="4.0.0",
        removal="4.1.0",
    )
    storage = AnswerTestingTest.result_storage
    if isinstance(sim_fn, SimulationTimeSeries):
        return storage is not None
    data_dir = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(data_dir):
        return False
    if file_check:
        return storage is not None and os.path.isfile(os.path.join(data_dir, sim_fn))
    try:
        load_simulation(sim_fn, sim_type)
    except FileNotFoundError:
        if ytcfg.get("yt", "internals", "strict_requires"):
            # Mark the shared result storage so no empty answer set is
            # written out for this run.
            if storage is not None:
                storage["tainted"] = True
            raise
        return False
    return storage is not None
342
343
def data_dir_load(ds_fn, cls=None, args=None, kwargs=None):
    """Load *ds_fn* from the test data directory and build its index.

    Returns the dataset unchanged if it is already a ``Dataset``; returns
    ``False`` when the test data directory does not exist.  When *cls* is
    given it is used as the loader instead of ``yt.load``.
    """
    call_args = args or ()
    call_kwargs = kwargs or {}
    data_dir = ytcfg.get("yt", "test_data_dir")
    if isinstance(ds_fn, Dataset):
        return ds_fn
    if not os.path.isdir(data_dir):
        return False
    if cls is None:
        ds = load(ds_fn, *call_args, **call_kwargs)
    else:
        ds = cls(os.path.join(data_dir, ds_fn), *call_args, **call_kwargs)
    # Touch the index so it is fully constructed before the test runs.
    ds.index
    return ds
358
359
def sim_dir_load(sim_fn, path=None, sim_type="Enzo", find_outputs=False):
    """Load the simulation *sim_fn*, optionally searching under *path*.

    Raises OSError when *sim_fn* does not exist and no search path was
    supplied.
    """
    if path is None and not os.path.exists(sim_fn):
        # Previously raised a bare OSError with no message.
        raise OSError(f"Simulation file {sim_fn!r} not found and no path given")
    if os.path.exists(sim_fn) or not path:
        path = "."
    return load_simulation(
        os.path.join(path, sim_fn), sim_type, find_outputs=find_outputs
    )
368
369
class AnswerTestingTest:
    """Base class for a single answer test.

    Subclasses implement ``run()`` (produce a result) and ``compare()``
    (check it against a stored answer), and define ``_type_name`` plus
    ``_attrs`` from which the test's description string is built.
    Calling an instance either compares its result against the reference
    storage or records it for storing, depending on the plugin options.
    """

    # Shared state installed by the AnswerTesting nose plugin at
    # configure() time; class-level so every test instance sees it.
    reference_storage = None
    result_storage = None
    prefix = ""
    options = None
    # This variable should be set if we are not providing `--answer-name` as
    # command line parameter while running yt's answer testing using nosetests.
    answer_name = None

    def __init__(self, ds_fn):
        # Accept None (no dataset), an already-loaded Dataset, or a
        # filename to be loaded from the test data directory (in code
        # units so answers are unit-system independent).
        if ds_fn is None:
            self.ds = None
        elif isinstance(ds_fn, Dataset):
            self.ds = ds_fn
        else:
            self.ds = data_dir_load(ds_fn, kwargs={"unit_system": "code"})

    def __call__(self):
        """Run the test, then compare or record its result.

        Raises YTNoAnswerNameSpecified when no answer name is available,
        and YTNoOldAnswer when comparison is requested but no stored
        answer exists for this test.
        """
        if AnswerTestingTest.result_storage is None:
            return
        nv = self.run()

        # Test answer name should be provided either as command line parameters
        # or by setting AnswerTestingTest.answer_name
        if self.options.answer_name is None and self.answer_name is None:
            raise YTNoAnswerNameSpecified()

        # This is for running answer test when `--answer-name` is not set in
        # nosetests command line arguments. In this case, set the answer_name
        # from the `answer_name` keyword in the test case
        if self.options.answer_name is None:
            # Answers are versioned per Python major.minor version.
            pyver = f"py{sys.version_info.major}{sys.version_info.minor}"
            self.answer_name = f"{pyver}_{self.answer_name}"

            answer_store_dir = os.path.realpath(self.options.output_dir)
            ref_name = os.path.join(
                answer_store_dir, self.answer_name, self.answer_name
            )
            self.reference_storage.reference_name = ref_name
            self.reference_storage.answer_name = ref_name

            # If we are generating golden answers (passed --answer-store arg):
            # - create the answer directory for this test
            # - self.reference_storage.answer_name will be path to answer files
            if self.options.store_results:
                answer_test_dir = os.path.join(answer_store_dir, self.answer_name)
                if not os.path.isdir(answer_test_dir):
                    os.makedirs(answer_test_dir)
                # A None reference name disables comparison below.
                self.reference_storage.reference_name = None

        if self.reference_storage.reference_name is not None:
            # Compare test generated values against the golden answer
            dd = self.reference_storage.get(self.storage_name)
            if dd is None or self.description not in dd:
                raise YTNoOldAnswer(f"{self.storage_name} : {self.description}")
            ov = dd[self.description]
            self.compare(nv, ov)
        else:
            # Store results, hence do nothing (in case of --answer-store arg)
            ov = None
        # Always record the new value; the plugin's finalize() dumps it
        # when we are in store mode.
        self.result_storage[self.storage_name][self.description] = nv

    @property
    def storage_name(self):
        # Key under which this test's answers are grouped; usually the
        # dataset name, optionally namespaced by a prefix.
        if self.prefix != "":
            return f"{self.prefix}_{self.ds}"
        return str(self.ds)

    def compare(self, new_result, old_result):
        # Subclasses must override; comparing with the base class is a bug.
        raise RuntimeError

    def create_plot(self, ds, plot_type, plot_field, plot_axis, plot_kwargs=None):
        """Instantiate a plot of the named type for *ds*.

        *plot_type* should be a string naming a class in plot_window or,
        failing that, particle_plots; *plot_kwargs* should be a dict.
        """
        if plot_type is None:
            raise RuntimeError("Must explicitly request a plot type")
        cls = getattr(pw, plot_type, None)
        if cls is None:
            cls = getattr(particle_plots, plot_type)
        plot = cls(*(ds, plot_axis, plot_field), **plot_kwargs)
        return plot

    @property
    def sim_center(self):
        """
        This returns the center of the domain.
        """
        return 0.5 * (self.ds.domain_right_edge + self.ds.domain_left_edge)

    @property
    def max_dens_location(self):
        """
        This is a helper function to return the location of the most dense
        point.
        """
        return self.ds.find_max(("gas", "density"))[1]

    @property
    def entire_simulation(self):
        """
        Return an unsorted array of values that cover the entire domain.
        """
        return self.ds.all_data()

    @property
    def description(self):
        """Unique string identifying this test within its dataset's answers.

        Built from the test type, dataset, data-object type, and the
        subclass-declared ``_attrs``; dots are replaced so the string is
        safe as a storage key.
        """
        obj_type = getattr(self, "obj_type", None)
        if obj_type is None:
            oname = "all"
        else:
            oname = "_".join(str(s) for s in obj_type)
        args = [self._type_name, str(self.ds), oname]
        args += [str(getattr(self, an)) for an in self._attrs]
        suffix = getattr(self, "suffix", None)
        if suffix:
            args.append(suffix)
        return "_".join(args).replace(".", "_")
487
488
class FieldValuesTest(AnswerTestingTest):
    """Answer test comparing the weighted average and extrema of a field."""

    _type_name = "FieldValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field, obj_type=None, particle_type=False, decimals=10):
        super().__init__(ds_fn)
        self.obj_type = obj_type
        self.field = field
        self.particle_type = particle_type
        self.decimals = decimals

    def run(self):
        """Return ``[weighted average, minimum, maximum]`` of the field."""
        data_obj = create_obj(self.ds, self.obj_type)
        resolved = data_obj._determine_fields(self.field)[0]
        finfo = self.ds.field_info[resolved]
        # Pick a weighting field appropriate to the field's sampling type.
        if self.particle_type:
            weight = (resolved[0], "particle_ones")
        elif finfo.is_sph_field:
            weight = (resolved[0], "ones")
        else:
            weight = ("index", "ones")
        average = data_obj.quantities.weighted_average_quantity(
            resolved, weight=weight
        )
        minimum, maximum = data_obj.quantities.extrema(self.field)
        return [average, minimum, maximum]

    def compare(self, new_result, old_result):
        """Check the two ``[avg, min, max]`` triples agree to self.decimals."""
        err_msg = f"Field values for {self.field} not equal."
        new_vals = new_result.d if hasattr(new_result, "d") else new_result
        old_vals = old_result.d if hasattr(old_result, "d") else old_result
        if self.decimals is None:
            assert_equal(new_vals, old_vals, err_msg=err_msg, verbose=True)
            return
        # What we do here is check if the old_result has units; if not, we
        # assume they will be the same as the units of new_result.
        if isinstance(old_vals, np.ndarray) and not hasattr(old_vals, "in_units"):
            # coerce it here to the same units
            old_vals = old_vals * new_vals[0].uq
        assert_allclose_units(
            new_vals,
            old_vals,
            10.0 ** (-self.decimals),
            err_msg=err_msg,
            verbose=True,
        )
537
538
class AllFieldValuesTest(AnswerTestingTest):
    """Answer test comparing every value of a field on a data object."""

    _type_name = "AllFieldValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field, obj_type=None, decimals=None):
        super().__init__(ds_fn)
        self.obj_type = obj_type
        self.field = field
        self.decimals = decimals

    def run(self):
        """Return the full field array for the configured data object."""
        data_obj = create_obj(self.ds, self.obj_type)
        return data_obj[self.field]

    def compare(self, new_result, old_result):
        """Check the arrays agree exactly, or to self.decimals places."""
        err_msg = f"All field values for {self.field} not equal."
        new_vals = new_result.d if hasattr(new_result, "d") else new_result
        old_vals = old_result.d if hasattr(old_result, "d") else old_result
        if self.decimals is None:
            assert_equal(new_vals, old_vals, err_msg=err_msg, verbose=True)
        else:
            assert_rel_equal(
                new_vals, old_vals, self.decimals, err_msg=err_msg, verbose=True
            )
565
566
class ProjectionValuesTest(AnswerTestingTest):
    """Answer test comparing the raw field data of a projection."""

    _type_name = "ProjectionValues"
    _attrs = ("field", "axis", "weight_field")

    def __init__(
        self, ds_fn, axis, field, weight_field=None, obj_type=None, decimals=10
    ):
        super().__init__(ds_fn)
        self.axis = axis
        self.field = field
        self.weight_field = weight_field
        self.obj_type = obj_type
        self.decimals = decimals

    def run(self):
        """Project ``self.field`` along ``self.axis`` and return field data.

        Returns None when the domain is a single cell thick along the
        projection axis, in which case compare() is a no-op.
        """
        if self.obj_type is not None:
            obj = create_obj(self.ds, self.obj_type)
        else:
            obj = None
        if self.ds.domain_dimensions[self.axis] == 1:
            return None
        proj = self.ds.proj(
            self.field, self.axis, weight_field=self.weight_field, data_source=obj
        )
        return proj.field_data

    def compare(self, new_result, old_result):
        """Compare two projection field-data mappings to self.decimals.

        NaN entries are masked out: for each result, the union of NaN
        positions across all of its fields is excluded before comparing.
        """
        if new_result is None:
            return
        assert len(new_result) == len(old_result)
        nind, oind = None, None
        # Accumulate a union-of-NaNs mask per result (in-place logical_or
        # writes into the accumulator arrays).
        for k in new_result:
            assert k in old_result
            if oind is None:
                oind = np.array(np.isnan(old_result[k]))
            np.logical_or(oind, np.isnan(old_result[k]), oind)
            if nind is None:
                nind = np.array(np.isnan(new_result[k]))
            np.logical_or(nind, np.isnan(new_result[k]), nind)
        # Invert: True where the values are finite and comparable.
        oind = ~oind
        nind = ~nind
        for k in new_result:
            err_msg = (
                "%s values of %s (%s weighted) projection (axis %s) not equal."
                % (k, self.field, self.weight_field, self.axis)
            )
            if k == "weight_field":
                # Our weight_field can vary between unit systems, whereas we
                # can do a unitful comparison for the other fields.  So we do
                # not do the test here.
                continue
            nres, ores = new_result[k][nind], old_result[k][oind]
            # Strip unit wrappers down to plain ndarrays where present.
            if hasattr(nres, "d"):
                nres = nres.d
            if hasattr(ores, "d"):
                ores = ores.d
            if self.decimals is None:
                assert_equal(nres, ores, err_msg=err_msg)
            else:
                assert_allclose_units(
                    nres, ores, 10.0 ** -(self.decimals), err_msg=err_msg
                )
629
630
class PixelizedProjectionValuesTest(AnswerTestingTest):
    """Answer test comparing pixelized (fixed-resolution buffer) projections."""

    _type_name = "PixelizedProjectionValues"
    _attrs = ("field", "axis", "weight_field")

    def __init__(self, ds_fn, axis, field, weight_field=None, obj_type=None):
        super().__init__(ds_fn)
        self.axis = axis
        self.field = field
        self.weight_field = weight_field
        self.obj_type = obj_type

    def _get_frb(self, obj):
        """Build the projection and a 256x256 FRB over one unitary length.

        Subclasses override this to change how the projection/FRB pair is
        produced (e.g. particle projections).
        """
        proj = self.ds.proj(
            self.field, self.axis, weight_field=self.weight_field, data_source=obj
        )
        frb = proj.to_frb((1.0, "unitary"), 256)
        return proj, frb

    def run(self):
        """Return the pixelized images plus per-field projection sums."""
        if self.obj_type is not None:
            obj = create_obj(self.ds, self.obj_type)
        else:
            obj = None
        # BUG FIX: this method previously re-implemented _get_frb inline,
        # which meant subclass overrides of _get_frb (e.g. the particle
        # projection variant) were never exercised.  Delegate instead.
        proj, frb = self._get_frb(obj)
        frb[self.field]
        if self.weight_field is not None:
            frb[self.weight_field]
        d = frb.data
        for f in proj.field_data:
            # Sometimes f will be a tuple.
            d[f"{f}_sum"] = proj.field_data[f].sum(dtype="float64")
        return d

    def compare(self, new_result, old_result):
        """Check every image/sum agrees to 1e-10; on mismatch, dump both
        images as attachments before re-raising."""
        assert len(new_result) == len(old_result)
        for k in new_result:
            assert k in old_result
        for k in new_result:
            # weight_field does not have units, so we do not directly compare them
            if k == "weight_field_sum":
                continue
            try:
                assert_allclose_units(new_result[k], old_result[k], 1e-10)
            except AssertionError:
                dump_images(new_result[k], old_result[k])
                raise
680
681
class PixelizedParticleProjectionValuesTest(PixelizedProjectionValuesTest):
    """Pixelized projection answer test for particle fields.

    Overrides ``_get_frb`` to build the FRB from a ParticleProjectionPlot
    instead of ``ds.proj``.
    """

    def _get_frb(self, obj):
        # NOTE(review): *obj* is accepted for signature compatibility with
        # the parent class but is not forwarded to the plot — confirm
        # whether data_source filtering is intended here.
        proj_plot = particle_plots.ParticleProjectionPlot(
            self.ds, self.axis, [self.field], weight_field=self.weight_field
        )
        return proj_plot.data_source, proj_plot.frb
688
689
class GridValuesTest(AnswerTestingTest):
    """Answer test hashing the values of a field on every grid."""

    _type_name = "GridValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field):
        super().__init__(ds_fn)
        self.field = field

    def run(self):
        """Return a mapping of grid id -> md5 hex digest of field values."""
        digests = {}
        for grid in self.ds.index.grids:
            digests[grid.id] = hashlib.md5(grid[self.field].tobytes()).hexdigest()
            # Free the grid's field cache as we go to keep memory bounded.
            grid.clear_data()
        return digests

    def compare(self, new_result, old_result):
        """Require identical grid sets with identical per-grid digests."""
        assert len(new_result) == len(old_result)
        for key in new_result:
            assert key in old_result
        for key in new_result:
            if hasattr(new_result[key], "d"):
                new_result[key] = new_result[key].d
            if hasattr(old_result[key], "d"):
                old_result[key] = old_result[key].d
            assert_equal(new_result[key], old_result[key])
715
716
class VerifySimulationSameTest(AnswerTestingTest):
    """Answer test verifying a simulation reproduces the same output times."""

    _type_name = "VerifySimulationSame"
    _attrs = ()

    def __init__(self, simulation_obj):
        # Deliberately skips super().__init__: the "dataset" here is a
        # simulation time series, not a single loadable dataset.
        self.ds = simulation_obj

    def run(self):
        """Return the current_time of every output in the series."""
        return [output.current_time for output in self.ds]

    def compare(self, new_result, old_result):
        """Require identical output counts and matching times per output."""
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        # Lengths are equal (asserted above), so zip covers every pair.
        for new_time, old_time in zip(new_result, old_result):
            assert_equal(
                new_time,
                old_time,
                err_msg="Output times not equal.",
                verbose=True,
            )
742
743
class GridHierarchyTest(AnswerTestingTest):
    """Answer test recording the dataset's grid hierarchy layout."""

    _type_name = "GridHierarchy"
    _attrs = ()

    def run(self):
        """Return the hierarchy arrays keyed by their answer names."""
        index = self.ds.index
        return {
            "grid_dimensions": index.grid_dimensions,
            "grid_left_edges": index.grid_left_edge,
            "grid_right_edges": index.grid_right_edge,
            "grid_levels": index.grid_levels,
            "grid_particle_count": index.grid_particle_count,
        }

    def compare(self, new_result, old_result):
        """Require every hierarchy array to match exactly."""
        for key in new_result:
            if hasattr(new_result[key], "d"):
                new_result[key] = new_result[key].d
            if hasattr(old_result[key], "d"):
                old_result[key] = old_result[key].d
            assert_equal(new_result[key], old_result[key])
764
765
class ParentageRelationshipsTest(AnswerTestingTest):
    """Answer test recording parent/child relationships between grids."""

    _type_name = "ParentageRelationships"
    _attrs = ()

    def run(self):
        """Return parent ids and child-id lists for every grid, in order."""
        parents = []
        children = []
        for grid in self.ds.index.grids:
            parent = grid.Parent
            if parent is None:
                parents.append(None)
            elif hasattr(parent, "id"):
                parents.append(parent.id)
            else:
                # Parent may be a collection of grids rather than one grid.
                parents.append([pg.id for pg in parent])
            children.append([child.id for child in grid.Children])
        return {"parents": parents, "children": children}

    def compare(self, new_result, old_result):
        """Require identical parent and child relationships, grid by grid."""
        for new_parent, old_parent in zip(
            new_result["parents"], old_result["parents"]
        ):
            assert new_parent == old_parent
        for new_children, old_children in zip(
            new_result["children"], old_result["children"]
        ):
            assert new_children == old_children
790
791
def dump_images(new_result, old_result, decimals=10):
    """Write both results as PNG projections, compare them, and emit
    Jenkins ``[[ATTACHMENT|...]]`` markers for any comparison artifacts."""

    def _fresh_png():
        # One-line purpose: allocate a closed temporary .png path.
        fd, name = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        return name

    old_image = _fresh_png()
    new_image = _fresh_png()
    image_writer.write_projection(new_result, new_image)
    image_writer.write_projection(old_result, old_image)
    results = compare_images(old_image, new_image, 10 ** (-decimals))
    if results is not None:
        attachments = [
            line.strip() for line in results.split("\n") if line.endswith(".png")
        ]
        for fn in attachments:
            sys.stderr.write(f"\n[[ATTACHMENT|{fn}]]")
        sys.stderr.write("\n")
807
808
def compare_image_lists(new_result, old_result, decimals):
    """Compare two lists of zlib-compressed, pickled image arrays.

    Each entry is decompressed, unpickled, written to a temporary PNG and
    compared with matplotlib's ``compare_images`` at a tolerance of
    ``10**-decimals``.  Under Jenkins, diff artifacts are announced via
    ``[[ATTACHMENT|...]]`` markers.  Raises an assertion error on any
    mismatch.
    """
    fns = []
    for _ in range(2):
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        fns.append(tmpname)
    num_images = len(old_result)
    assert num_images > 0
    for i in range(num_images):
        # BUG FIX: np.loads was removed from NumPy; the payload is an
        # ordinary pickle of an ndarray, so unpickle it directly.
        mpimg.imsave(fns[0], pickle.loads(zlib.decompress(old_result[i])))
        mpimg.imsave(fns[1], pickle.loads(zlib.decompress(new_result[i])))
        results = compare_images(fns[0], fns[1], 10 ** (-decimals))
        if results is not None:
            if os.environ.get("JENKINS_HOME") is not None:
                tempfiles = [
                    line.strip()
                    for line in results.split("\n")
                    if line.endswith(".png")
                ]
                for fn in tempfiles:
                    sys.stderr.write(f"\n[[ATTACHMENT|{fn}]]")
                sys.stderr.write("\n")
        assert_equal(results, None, results)
        for fn in fns:
            os.remove(fn)
834
835
class VRImageComparisonTest(AnswerTestingTest):
    """Answer test comparing a rendered volume-rendering scene image."""

    _type_name = "VRImageComparison"
    _attrs = ("desc",)

    def __init__(self, scene, ds, desc, decimals):
        super().__init__(None)
        self.obj_type = ("vr",)
        self.ds = ds
        self.scene = scene
        self.desc = desc
        self.decimals = decimals

    def run(self):
        """Render the scene and return the image, pickled and compressed."""
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        self.scene.save(tmpname, sigma_clip=1.0)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        # BUG FIX: ndarray.dumps() was removed from NumPy; pickle.dumps is
        # its documented equivalent and what compare_image_lists unpickles.
        return [zlib.compress(pickle.dumps(image))]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)
858
859
class PlotWindowAttributeTest(AnswerTestingTest):
    """Answer test that renders a plot, invokes one plot attribute with
    given arguments, and compares the resulting image."""

    _type_name = "PlotWindowAttribute"
    _attrs = (
        "plot_type",
        "plot_field",
        "plot_axis",
        "attr_name",
        "attr_args",
        "callback_id",
    )

    def __init__(
        self,
        ds_fn,
        plot_field,
        plot_axis,
        attr_name,
        attr_args,
        decimals,
        plot_type="SlicePlot",
        callback_id="",
        callback_runners=None,
    ):
        super().__init__(ds_fn)
        self.plot_type = plot_type
        self.plot_field = plot_field
        self.plot_axis = plot_axis
        self.plot_kwargs = {}
        self.attr_name = attr_name
        self.attr_args = attr_args
        self.decimals = decimals
        # callback_id is so that we don't have to hash the actual callbacks
        # run, but instead we call them something
        self.callback_id = callback_id
        if callback_runners is None:
            callback_runners = []
        self.callback_runners = callback_runners

    def run(self):
        """Create the plot, run the callbacks, apply the attribute call,
        and return the saved image, pickled and compressed."""
        plot = self.create_plot(
            self.ds, self.plot_type, self.plot_field, self.plot_axis, self.plot_kwargs
        )
        for r in self.callback_runners:
            r(self, plot)
        attr = getattr(plot, self.attr_name)
        attr(*self.attr_args[0], **self.attr_args[1])
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        plot.save(name=tmpname)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        # BUG FIX: ndarray.dumps() was removed from NumPy; pickle.dumps is
        # its documented equivalent and what compare_image_lists unpickles.
        return [zlib.compress(pickle.dumps(image))]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)
915
916
class PhasePlotAttributeTest(AnswerTestingTest):
    """Answer test that builds a phase/profile plot, invokes one of its
    attribute methods, and compares the saved image against the stored
    answer."""

    _type_name = "PhasePlotAttribute"
    _attrs = ("plot_type", "x_field", "y_field", "z_field", "attr_name", "attr_args")

    def __init__(
        self,
        ds_fn,
        x_field,
        y_field,
        z_field,
        attr_name,
        attr_args,
        decimals,
        plot_type="PhasePlot",
    ):
        super().__init__(ds_fn)
        self.data_source = self.ds.all_data()
        self.plot_type = plot_type
        self.x_field = x_field
        self.y_field = y_field
        self.z_field = z_field
        self.plot_kwargs = {}
        self.attr_name = attr_name
        self.attr_args = attr_args
        self.decimals = decimals

    def create_plot(
        self, data_source, x_field, y_field, z_field, plot_type, plot_kwargs=None
    ):
        """Instantiate the plot class named by *plot_type* (looked up first in
        profile_plotter, then particle_plots) for the given fields.

        plot_type must be a string; plot_kwargs is an optional dict of extra
        keyword arguments for the plot class.
        """
        if plot_type is None:
            raise RuntimeError("Must explicitly request a plot type")
        # BUG FIX: the default plot_kwargs=None previously crashed with a
        # TypeError at the ** expansion below whenever the default was used.
        if plot_kwargs is None:
            plot_kwargs = {}
        cls = getattr(profile_plotter, plot_type, None)
        if cls is None:
            cls = getattr(particle_plots, plot_type)
        return cls(data_source, x_field, y_field, z_field, **plot_kwargs)

    def run(self):
        plot = self.create_plot(
            self.data_source,
            self.x_field,
            self.y_field,
            self.z_field,
            self.plot_type,
            self.plot_kwargs,
        )
        # Invoke the attribute under test with its positional/keyword args.
        attr = getattr(plot, self.attr_name)
        attr(*self.attr_args[0], **self.attr_args[1])
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        plot.save(name=tmpname)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        return [zlib.compress(image.dumps())]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)
976
977
class GenericArrayTest(AnswerTestingTest):
    """Answer test wrapping an arbitrary user-supplied callable whose return
    value (a dict of arrays, or a single array) is compared against the
    stored answer."""

    _type_name = "GenericArray"
    _attrs = ("array_func_name", "args", "kwargs")

    def __init__(self, ds_fn, array_func, args=None, kwargs=None, decimals=None):
        super().__init__(ds_fn)
        self.array_func = array_func
        self.array_func_name = array_func.__name__
        self.args = args
        self.kwargs = kwargs
        self.decimals = decimals

    def run(self):
        call_args = self.args if self.args is not None else []
        call_kwargs = self.kwargs if self.kwargs is not None else {}
        return self.array_func(*call_args, **call_kwargs)

    def compare(self, new_result, old_result):
        # Wrap non-dict results so both shapes share the comparison loop.
        if not isinstance(new_result, dict):
            new_result = {"answer": new_result}
            old_result = {"answer": old_result}

        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        for key in new_result:
            # Strip unit wrappers (".d" attribute) before numeric comparison.
            if hasattr(new_result[key], "d"):
                new_result[key] = new_result[key].d
            if hasattr(old_result[key], "d"):
                old_result[key] = old_result[key].d
            if self.decimals is None:
                assert_almost_equal(new_result[key], old_result[key])
            else:
                assert_allclose_units(
                    new_result[key], old_result[key], 10 ** (-self.decimals)
                )
1023
1024
class GenericImageTest(AnswerTestingTest):
    """Answer test wrapping a user-supplied callable that writes one or more
    image files under a given filename prefix; the images are compared
    against the stored answers."""

    _type_name = "GenericImage"
    _attrs = ("image_func_name", "args", "kwargs")

    def __init__(self, ds_fn, image_func, decimals, args=None, kwargs=None):
        super().__init__(ds_fn)
        self.image_func = image_func
        self.image_func_name = image_func.__name__
        self.args = args
        self.kwargs = kwargs
        self.decimals = decimals

    def run(self):
        call_args = [] if self.args is None else self.args
        call_kwargs = {} if self.kwargs is None else self.kwargs
        workdir = tempfile.mkdtemp()
        prefix = os.path.join(workdir, "test_img")
        # The callable writes its image(s) using this prefix; collect them in
        # sorted (deterministic) order.
        self.image_func(prefix, *call_args, **call_kwargs)
        produced = sorted(glob.glob(prefix + "*"))
        assert len(produced) > 0
        compressed = []
        for path in produced:
            pixels = mpimg.imread(path)
            os.remove(path)
            compressed.append(zlib.compress(pixels.dumps()))
        return compressed

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)
1060
1061
class AxialPixelizationTest(AnswerTestingTest):
    """Check that basic axial pixelization results do not change.

    Typically used once per geometry/coordinates type: for each axis, a slice
    through the domain center is pixelized in both image-plane coordinates
    and compared against the stored answer.
    """

    _type_name = "AxialPixelization"
    _attrs = ("geometry",)

    def __init__(self, ds_fn, decimals=None):
        super().__init__(ds_fn)
        self.decimals = decimals
        self.geometry = self.ds.coordinates.name

    def run(self):
        ds = self.ds
        coords = ds.coordinates
        answers = {}
        for i, axis in enumerate(coords.axis_order):
            (bounds, center, display_center) = pw.get_window_parameters(
                axis, ds.domain_center, None, ds
            )
            slc = ds.slice(axis, center[i])
            xax = coords.axis_name[coords.x_axis[axis]]
            yax = coords.axis_name[coords.y_axis[axis]]
            pix_x = coords.pixelize(axis, slc, ("gas", xax), bounds, (512, 512))
            pix_y = coords.pixelize(axis, slc, ("gas", yax), bounds, (512, 512))
            # Wipe out invalid values (fillers)
            pix_x[~np.isfinite(pix_x)] = 0.0
            pix_y[~np.isfinite(pix_y)] = 0.0
            answers[f"{axis}_x"] = pix_x
            answers[f"{axis}_y"] = pix_y
        return answers

    def compare(self, new_result, old_result):
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        for key in new_result:
            # Strip unit wrappers (".d" attribute) before numeric comparison.
            if hasattr(new_result[key], "d"):
                new_result[key] = new_result[key].d
            if hasattr(old_result[key], "d"):
                old_result[key] = old_result[key].d
            if self.decimals is None:
                assert_almost_equal(new_result[key], old_result[key])
            else:
                assert_allclose_units(
                    new_result[key], old_result[key], 10 ** (-self.decimals)
                )
1111
1112
def requires_sim(sim_fn, sim_type, big_data=False, file_check=False):
    """Decorator factory: skip the decorated test unless the named
    simulation is available (and, for big-data tests, big data is enabled).

    Deprecated; retained for backward compatibility.
    """
    issue_deprecation_warning(
        "This function is no longer used in the "
        "yt project testing framework and is "
        "targeted for deprecation.",
        since="4.0.0",
        removal="4.1.0",
    )

    from functools import wraps

    from nose import SkipTest

    def skip_decorator(func):
        @wraps(func)
        def skipper(*args, **kwargs):
            raise SkipTest

        return skipper

    def identity_decorator(func):
        return func

    if big_data and not run_big_data:
        return skip_decorator
    if not can_run_sim(sim_fn, sim_type, file_check):
        return skip_decorator
    return identity_decorator
1142
1143
def requires_answer_testing():
    """Decorator factory: skip the decorated test unless the answer-testing
    plugin is active (i.e. result storage has been set up)."""
    from functools import wraps

    from nose import SkipTest

    def skip_decorator(func):
        @wraps(func)
        def skipper(*args, **kwargs):
            raise SkipTest

        return skipper

    def identity_decorator(func):
        return func

    if AnswerTestingTest.result_storage is None:
        return skip_decorator
    return identity_decorator
1163
1164
def requires_ds(ds_fn, big_data=False, file_check=False):
    """Decorator factory: skip the decorated test unless the named dataset
    is available (and, for big-data tests, big data is enabled)."""
    from functools import wraps

    from nose import SkipTest

    def skip_decorator(func):
        @wraps(func)
        def skipper(*args, **kwargs):
            raise SkipTest

        return skipper

    def identity_decorator(func):
        return func

    if big_data and not run_big_data:
        return skip_decorator
    if not can_run_ds(ds_fn, file_check):
        return skip_decorator
    return identity_decorator
1186
1187
def small_patch_amr(ds_fn, fields, input_center="max", input_weight=("gas", "density")):
    """Yield the standard battery of answer tests for a small patch-AMR
    dataset: hierarchy and parentage checks, plus grid-, projection-, and
    field-value tests for each requested field.
    """
    if not can_run_ds(ds_fn):
        return
    # Data objects to test against: the full domain (None) and a small sphere.
    dso = [None, ("sphere", (input_center, (0.1, "unitary")))]
    yield GridHierarchyTest(ds_fn)
    yield ParentageRelationshipsTest(ds_fn)
    for field in fields:
        yield GridValuesTest(ds_fn, field)
        for axis in [0, 1, 2]:
            for dobj_name in dso:
                for weight_field in [None, input_weight]:
                    yield ProjectionValuesTest(
                        ds_fn, axis, field, weight_field, dobj_name
                    )
                # NOTE(review): this sits inside the axis loop, so the same
                # (field, dobj) FieldValuesTest is yielded once per axis —
                # presumably historical; "fixing" it would change the answer
                # test count, so verify before moving it out of the loop.
                yield FieldValuesTest(ds_fn, field, dobj_name)
1203
1204
def big_patch_amr(ds_fn, fields, input_center="max", input_weight=("gas", "density")):
    """Yield the standard battery of answer tests for a large patch-AMR
    dataset (pixelized projections instead of exact projection values)."""
    if not can_run_ds(ds_fn):
        return
    # Data objects to test against: the full domain (None) and a small sphere.
    data_objects = [None, ("sphere", (input_center, (0.1, "unitary")))]
    yield GridHierarchyTest(ds_fn)
    yield ParentageRelationshipsTest(ds_fn)
    for fld in fields:
        yield GridValuesTest(ds_fn, fld)
        for ax in (0, 1, 2):
            for dobj in data_objects:
                for wt in (None, input_weight):
                    yield PixelizedProjectionValuesTest(ds_fn, ax, fld, wt, dobj)
1219
1220
def _particle_answers(
    ds, ds_str_repr, ds_nparticles, fields, proj_test_class, center="c"
):
    """Shared generator behind nbody_answer/sph_answer: sanity-check the
    dataset's string form and particle counts, then yield projection and
    field-value tests for each (field, weight) pair and data object."""
    if not can_run_ds(ds):
        return
    assert_equal(str(ds), ds_str_repr)
    dso = [None, ("sphere", (center, (0.1, "unitary")))]
    dd = ds.all_data()
    # The "all" particle union must explicitly be used here and must see
    # every particle in the dataset.
    assert_equal(dd["all", "particle_position"].shape, (ds_nparticles, 3))
    per_type_counts = [
        dd[ptype, "particle_position"].shape[0] for ptype in ds.particle_types_raw
    ]
    assert_equal(sum(per_type_counts), ds_nparticles)
    for dobj_name in dso:
        for field, weight_field in fields.items():
            particle_type = field[0] in ds.particle_types
            # Projection tests only make sense for non-particle (mesh) fields.
            if not particle_type:
                for axis in [0, 1, 2]:
                    yield proj_test_class(ds, axis, field, weight_field, dobj_name)
            yield FieldValuesTest(ds, field, dobj_name, particle_type=particle_type)
1242
1243
def nbody_answer(ds, ds_str_repr, ds_nparticles, fields, center="c"):
    """Yield the standard particle answer tests for an N-body dataset,
    using pixelized *particle* projections."""
    return _particle_answers(
        ds, ds_str_repr, ds_nparticles, fields,
        PixelizedParticleProjectionValuesTest,
        center=center,
    )
1253
1254
def sph_answer(ds, ds_str_repr, ds_nparticles, fields, center="c"):
    """Yield the standard particle answer tests for an SPH dataset,
    using pixelized (mesh) projections."""
    return _particle_answers(
        ds, ds_str_repr, ds_nparticles, fields,
        PixelizedProjectionValuesTest,
        center=center,
    )
1264
1265
def create_obj(ds, obj_type):
    """Instantiate a data object on *ds*.

    *obj_type* is either None — meaning the full domain, ``ds.all_data()`` —
    or a tuple ``(obj_name, args)`` where ``obj_name`` names an attribute of
    *ds* and ``args`` are its positional arguments.
    """
    if obj_type is None:
        return ds.all_data()
    factory = getattr(ds, obj_type[0])
    return factory(*obj_type[1])
1274