1from yt.frontends.athena.api import AthenaDataset
2from yt.loaders import load
3from yt.testing import (
4    assert_allclose_units,
5    assert_equal,
6    disable_dataset_cache,
7    requires_file,
8)
9from yt.utilities.answer_testing.framework import (
10    data_dir_load,
11    requires_ds,
12    small_patch_amr,
13)
14
# Fields exercised by the ShockCloud answer tests in test_cloud.
_fields_cloud = (
    ("athena", "scalar[0]"),
    ("gas", "density"),
    ("gas", "total_energy"),
)

# ShockCloud snapshot loaded by test_cloud and test_AthenaDataset.
cloud = "ShockCloud/id0/Cloud.0050.vtk"
18
19
@requires_ds(cloud)
def test_cloud():
    """Run the small-patch AMR answer tests on the ShockCloud snapshot.

    Loads ``cloud``, checks the dataset's string form, then yields every
    answer test that ``small_patch_amr`` builds for ``_fields_cloud``.
    Before each yield, this function's ``__name__`` is set to the yielded
    test's description.
    """
    cloud_ds = data_dir_load(cloud)
    assert_equal(str(cloud_ds), "Cloud.0050")
    answer_tests = small_patch_amr(cloud_ds, _fields_cloud)
    for answer_test in answer_tests:
        test_cloud.__name__ = answer_test.description
        yield answer_test
27
28
# Fields exercised by the MHDBlast answer tests in test_blast; all of them
# live in the "gas" field namespace.
_fields_blast = tuple(
    ("gas", field_name)
    for field_name in ("temperature", "density", "velocity_magnitude")
)

# MHDBlast snapshot loaded by test_blast and test_blast_override.
blast = "MHDBlast/id0/Blast.0100.vtk"
36
37
@requires_ds(blast)
def test_blast():
    """Run the small-patch AMR answer tests on the MHDBlast snapshot.

    Loads ``blast``, checks the dataset's string form, then yields every
    answer test that ``small_patch_amr`` builds for ``_fields_blast``.
    Before each yield, this function's ``__name__`` is set to the yielded
    test's description.
    """
    blast_ds = data_dir_load(blast)
    assert_equal(str(blast_ds), "Blast.0100")
    answer_tests = small_patch_amr(blast_ds, _fields_blast)
    for answer_test in answer_tests:
        test_blast.__name__ = answer_test.description
        yield answer_test
45
46
# Unit overrides passed to load() as ``units_override`` in
# test_blast_override. Each entry is a (value, unit-string) pair.
uo_blast = dict(
    length_unit=(1.0, "pc"),
    mass_unit=(2.38858753789e-24, "g/cm**3*pc**3"),
    time_unit=(1.0, "s*pc/km"),
)
52
53
@requires_file(blast)
def test_blast_override():
    """Check that ``units_override`` also updates derived unit values.

    Loads ``blast`` with the ``uo_blast`` overrides and verifies the derived
    magnetic unit against a reference value (see issue #1259).
    """
    ds = load(blast, units_override=uo_blast)
    # The magnetic unit is derived from the overridden base units through
    # floating-point arithmetic, so compare it to the reference with a
    # relative tolerance rather than exact float equality. Both sides are
    # expressed in gauss so no cross-unit-system conversion is involved.
    assert_allclose_units(
        ds.magnetic_unit.in_units("gauss"),
        ds.quan(5.47867467969813e-07, "gauss"),
    )
60
61
# Unit overrides for the RamPressureStripping dataset, forwarded to
# data_dir_load via ``kwargs`` in test_stripping. Values are bare floats
# (presumably interpreted in yt's default cgs units -- TODO confirm).
# mass_unit is apparently a density (9.999e-30) times length_unit cubed.
uo_stripping = dict(
    time_unit=3.086e14,
    length_unit=8.0236e22,
    mass_unit=9.999e-30 * 8.0236e22**3,
)

# Fields exercised by the RamPressureStripping answer tests.
_fields_stripping = (
    ("gas", "temperature"),
    ("gas", "density"),
    ("athena", "specific_scalar[0]"),
)

# RamPressureStripping snapshot loaded by test_stripping.
stripping = "RamPressureStripping/id0/rps.0062.vtk"
75
76
@requires_ds(stripping, big_data=True)
def test_stripping():
    """Run the small-patch AMR answer tests on the ram-pressure-stripping run.

    The dataset is loaded with the ``uo_stripping`` unit overrides. After its
    string form is checked, every answer test that ``small_patch_amr`` builds
    for ``_fields_stripping`` is yielded. Before each yield, this function's
    ``__name__`` is set to the yielded test's description.
    """
    load_kwargs = dict(units_override=uo_stripping)
    rps_ds = data_dir_load(stripping, kwargs=load_kwargs)
    assert_equal(str(rps_ds), "rps.0062")
    answer_tests = small_patch_amr(rps_ds, _fields_stripping)
    for answer_test in answer_tests:
        test_stripping.__name__ = answer_test.description
        yield answer_test
84
85
# MHDSloshing snapshot loaded (twice) by test_nprocs.
sloshing = "MHDSloshing/virgo_low_res.0054.vtk"

# Unit overrides applied to both loads of ``sloshing`` in test_nprocs;
# each entry is a (value, unit-string) pair.
uo_sloshing = dict(
    length_unit=(1.0, "Mpc"),
    time_unit=(1.0, "Myr"),
    mass_unit=(1.0e14, "Msun"),
)
93
94
@requires_file(sloshing)
@disable_dataset_cache
def test_nprocs():
    """Check that loading with ``nprocs=8`` gives the same analysis results.

    The sloshing dataset is loaded twice with the same unit overrides: once
    normally and once with ``nprocs=8``. These results from each load are
    then compared:
    - pressure and velocity-component extrema, and the pressure total, in a
      100 kpc sphere;
    - the bulk velocity in that sphere;
    - a density projection along axis 0.
    """
    ds1 = load(sloshing, units_override=uo_sloshing)
    sp1 = ds1.sphere("c", (100.0, "kpc"))
    prj1 = ds1.proj(("gas", "density"), 0)
    ds2 = load(sloshing, units_override=uo_sloshing, nprocs=8)
    sp2 = ds2.sphere("c", (100.0, "kpc"))
    # Project the nprocs=8 dataset. Projecting ds1 again would make the final
    # projection assertion compare ds1 with itself and test nothing.
    prj2 = ds2.proj(("gas", "density"), 0)

    # Extrema are compared exactly. Accumulated quantities (totals, bulk
    # velocity, projections) use a relative tolerance, because their
    # summation order may depend on how the data are split (nprocs
    # presumably changes the grid decomposition -- TODO confirm).
    assert_equal(
        sp1.quantities.extrema(("gas", "pressure")),
        sp2.quantities.extrema(("gas", "pressure")),
    )
    assert_allclose_units(
        sp1.quantities.total_quantity(("gas", "pressure")),
        sp2.quantities.total_quantity(("gas", "pressure")),
    )
    for ax in "xyz":
        assert_equal(
            sp1.quantities.extrema(("gas", f"velocity_{ax}")),
            sp2.quantities.extrema(("gas", f"velocity_{ax}")),
        )
    assert_allclose_units(
        sp1.quantities.bulk_velocity(), sp2.quantities.bulk_velocity()
    )
    assert_allclose_units(prj1[("gas", "density")], prj2[("gas", "density")])
122
123
@requires_file(cloud)
def test_AthenaDataset():
    """Loading the ShockCloud snapshot must yield an ``AthenaDataset``."""
    loaded = data_dir_load(cloud)
    assert isinstance(loaded, AthenaDataset)
127