1import os
2import shutil
3import tempfile
4
5import numpy as np
6from nose.tools import assert_raises
7
8from yt.data_objects.particle_filters import add_particle_filter, particle_filter
9from yt.testing import assert_equal, fake_random_ds, fake_sph_grid_ds
10from yt.utilities.exceptions import YTIllDefinedFilter, YTIllDefinedParticleFilter
11from yt.visualization.plot_window import ProjectionPlot
12
13
def test_add_particle_filter():
    """Test particle filters created via add_particle_filter

    This accesses a deposition field using the particle filter, which was a
    problem in previous versions on this dataset because there are chunks with
    no stars in them.

    It also checks that a filter registered without ``requires`` works, and
    that attaching a filter whose required fields are not defined on the
    dataset raises YTIllDefinedParticleFilter with a descriptive message.
    """

    def stars(pfilter, data):
        # Keep particles of the filtered type whose mass exceeds 0.5.
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.5

    add_particle_filter(
        "stars1", function=stars, filtered_type="all", requires=["particle_mass"]
    )
    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("stars1")
    assert ("deposit", "stars1_cic") in ds.derived_field_list

    # Test without requires field
    add_particle_filter("stars2", function=stars)
    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("stars2")
    assert ("deposit", "stars2_cic") in ds.derived_field_list

    # Test adding filter with fields not defined on the ds.
    # Registration is done outside the assert_raises context so that only
    # attaching the filter to the dataset can produce the expected error.
    add_particle_filter(
        "bad_stars", function=stars, filtered_type="all", requires=["wrong_field"]
    )
    with assert_raises(YTIllDefinedParticleFilter) as ex:
        ds.add_particle_filter("bad_stars")
    actual = str(ex.exception)
    desired = (
        "\nThe fields\n\t('all', 'wrong_field'),\nrequired by the"
        ' "bad_stars" particle filter, are not defined for this dataset.'
    )
    assert_equal(actual, desired)
52
53
def test_add_particle_filter_overriding():
    """Test the add_particle_filter overriding

    Registering a second filter under an existing name must replace the
    stored function and emit a warning through ``mylog.warning``.
    """
    from yt.data_objects.particle_filters import filter_registry
    from yt.funcs import mylog

    def star_0(pfilter, data):
        pass

    def star_1(pfilter, data):
        pass

    # Use a closure to store whether the warning was called
    def closure(status):
        def warning_patch(*args, **kwargs):
            status[0] = True

        def was_called():
            return status[0]

        return warning_patch, was_called

    ## Test 1: we add a dummy particle filter
    add_particle_filter(
        "dummy", function=star_0, filtered_type="all", requires=["creation_time"]
    )
    assert "dummy" in filter_registry
    assert_equal(filter_registry["dummy"].function, star_0)

    ## Test 2: we add another dummy particle filter.
    ##         a warning is expected. We use the above closure to
    ##         check that.
    # Store the original warning function
    warning = mylog.warning
    monkey_warning, monkey_patch_was_called = closure([False])
    mylog.warning = monkey_warning
    try:
        add_particle_filter(
            "dummy", function=star_1, filtered_type="all", requires=["creation_time"]
        )
    finally:
        # Restore the original warning function even if registration fails,
        # so the patched logger cannot leak into other tests.
        mylog.warning = warning
    assert_equal(filter_registry["dummy"].function, star_1)
    assert_equal(monkey_patch_was_called(), True)
97
98
def test_particle_filter_decorator():
    """Check that the ``particle_filter`` decorator registers usable filters.

    The first filter is registered under the decorated function's name; the
    second is registered under an explicit ``name`` argument.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def heavy_stars(pfilter, data):
        mass = data[pfilter.filtered_type, "particle_mass"]
        return mass > 0.5

    dataset = fake_random_ds(16, nprocs=8, particles=16)
    dataset.add_particle_filter("heavy_stars")
    assert "heavy_stars" in dataset.particle_types
    assert ("deposit", "heavy_stars_cic") in dataset.derived_field_list

    # The explicit name, rather than the function name, is what the dataset uses.
    @particle_filter(name="my_stars", filtered_type="all", requires=["particle_mass"])
    def custom_stars(pfilter, data):
        mass = data[pfilter.filtered_type, "particle_mass"]
        return mass == 0.5

    dataset = fake_random_ds(16, nprocs=8, particles=16)
    dataset.add_particle_filter("my_stars")
    assert "my_stars" in dataset.particle_types
    assert ("deposit", "my_stars_cic") in dataset.derived_field_list
122
123
def test_particle_filter_exceptions():
    """Check error handling when a filter's function is ill-defined.

    A filter that returns the data container instead of a mask must raise
    YTIllDefinedFilter on access. A filter that reads a concrete particle
    type ("io") must still be accessible.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def filter1(pfilter, data):
        # Deliberately ill-defined: returns the container, not a boolean mask.
        return data

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("filter1")

    everything = ds.all_data()
    with assert_raises(YTIllDefinedFilter):
        everything["filter1", "particle_mass"].shape[0]

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def filter2(pfilter, data):
        # Reads "io" directly instead of pfilter.filtered_type.
        return data["io", "particle_mass"] > 0.5

    ds.add_particle_filter("filter2")
    everything = ds.all_data()
    everything["filter2", "particle_mass"].min()
144
145
def test_particle_filter_dependency():
    """
    Adding a filter whose ``filtered_type`` is itself a filter should also
    register that parent filter on the dataset automatically.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def h_stars(pfilter, data):
        return data[pfilter.filtered_type, "particle_mass"] > 0.5

    # Filters the output of h_stars, which makes h_stars a dependency.
    @particle_filter(filtered_type="h_stars", requires=["particle_mass"])
    def hh_stars(pfilter, data):
        return data[pfilter.filtered_type, "particle_mass"] > 0.9

    ds = fake_random_ds(16, nprocs=8, particles=16)
    # Only the child filter is added explicitly.
    ds.add_particle_filter("hh_stars")

    expected_types = ("hh_stars", "h_stars")
    for ptype in expected_types:
        assert ptype in ds.particle_types
    for ptype in expected_types:
        assert ("deposit", ptype + "_cic") in ds.derived_field_list
168
169
def test_covering_grid_particle_filter():
    """Check that a covering grid sees the same filtered particles as its grid.

    For every grid, build a covering grid with the same level, left edge and
    dimensions. Check that both select the same number of "heavy_stars"
    particles.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def heavy_stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.5

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("heavy_stars")

    field = ("heavy_stars", "particle_mass")
    for grid in ds.index.grids:
        # Covering grid spanning exactly the same region as this grid.
        cg = ds.covering_grid(grid.Level, grid.LeftEdge, grid.ActiveDimensions)

        assert_equal(cg[field].shape[0], grid[field].shape[0])
190
191
def test_sph_particle_filter_plotting():
    """Project a density field of an SPH particle filter and save the plot.

    The image is written inside a temporary directory. The working directory
    is restored and the temporary directory removed even if saving fails.
    """
    ds = fake_sph_grid_ds()

    @particle_filter("central_gas", requires=["particle_position"], filtered_type="io")
    def _filter(pfilter, data):
        # Keep particles with |x|, |y| and |z| all below 1.6.
        coords = np.abs(data[pfilter.filtered_type, "particle_position"])
        return (coords[:, 0] < 1.6) & (coords[:, 1] < 1.6) & (coords[:, 2] < 1.6)

    ds.add_particle_filter("central_gas")

    plot = ProjectionPlot(ds, "z", ("central_gas", "density"))
    curdir = os.getcwd()
    with tempfile.TemporaryDirectory() as tmpdir:
        os.chdir(tmpdir)
        try:
            plot.save()
        finally:
            # Leave tmpdir before TemporaryDirectory removes it; removing the
            # current working directory fails on some platforms.
            os.chdir(curdir)
211