"""Tests for yt particle filters (registration, decoration, dependencies, plotting)."""

import os
import shutil
import tempfile
from unittest import mock

import numpy as np
from nose.tools import assert_raises

from yt.data_objects.particle_filters import add_particle_filter, particle_filter
from yt.testing import assert_equal, fake_random_ds, fake_sph_grid_ds
from yt.utilities.exceptions import YTIllDefinedFilter, YTIllDefinedParticleFilter
from yt.visualization.plot_window import ProjectionPlot


def test_add_particle_filter():
    """Test particle filters created via add_particle_filter

    This accesses a deposition field using the particle filter, which was a
    problem in previous versions on this dataset because there are chunks with
    no stars in them.

    """

    def stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.5

    add_particle_filter(
        "stars1", function=stars, filtered_type="all", requires=["particle_mass"]
    )
    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("stars1")
    assert ("deposit", "stars1_cic") in ds.derived_field_list

    # Test without requires field
    add_particle_filter("stars2", function=stars)
    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("stars2")
    assert ("deposit", "stars2_cic") in ds.derived_field_list

    # Test adding filter with fields not defined on the ds.
    # Registering the filter itself is not expected to fail; only attaching it
    # to a dataset that lacks ("all", "wrong_field") should raise, so the
    # assert_raises scope covers just that call.
    add_particle_filter(
        "bad_stars", function=stars, filtered_type="all", requires=["wrong_field"]
    )
    with assert_raises(YTIllDefinedParticleFilter) as ex:
        ds.add_particle_filter("bad_stars")
    actual = str(ex.exception)
    desired = (
        "\nThe fields\n\t('all', 'wrong_field'),\nrequired by the"
        ' "bad_stars" particle filter, are not defined for this dataset.'
    )
    assert_equal(actual, desired)


def test_add_particle_filter_overriding():
    """Test that re-registering a filter name overrides it and emits a warning."""
    from yt.data_objects.particle_filters import filter_registry
    from yt.funcs import mylog

    def star_0(pfilter, data):
        pass

    def star_1(pfilter, data):
        pass

    ## Test 1: we add a dummy particle filter
    add_particle_filter(
        "dummy", function=star_0, filtered_type="all", requires=["creation_time"]
    )
    assert "dummy" in filter_registry
    assert_equal(filter_registry["dummy"].function, star_0)

    ## Test 2: we add another dummy particle filter under the same name.
    ## A warning is expected. mock.patch.object records the call and always
    ## restores mylog.warning on exit, even if an assertion fails, so the
    ## patched logger cannot leak into other tests.
    with mock.patch.object(mylog, "warning") as mocked_warning:
        add_particle_filter(
            "dummy", function=star_1, filtered_type="all", requires=["creation_time"]
        )
        assert_equal(filter_registry["dummy"].function, star_1)
        assert_equal(mocked_warning.called, True)


def test_particle_filter_decorator():
    """Test the particle_filter decorator"""

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def heavy_stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.5

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("heavy_stars")
    assert "heavy_stars" in ds.particle_types
    assert ("deposit", "heavy_stars_cic") in ds.derived_field_list

    # Test name of particle filter
    @particle_filter(name="my_stars", filtered_type="all", requires=["particle_mass"])
    def custom_stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] == 0.5

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("my_stars")
    assert "my_stars" in ds.particle_types
    assert ("deposit", "my_stars_cic") in ds.derived_field_list


def test_particle_filter_exceptions():
    """Test error reporting for an ill-defined filter function.

    ``filter1`` returns the data object itself instead of a boolean mask, so
    reading a field through it must raise YTIllDefinedFilter. ``filter2``
    indexes the hard-coded "io" particle type rather than
    ``pfilter.filtered_type``; reading a field through it must not raise.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def filter1(pfilter, data):
        return data

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("filter1")

    ad = ds.all_data()
    with assert_raises(YTIllDefinedFilter):
        ad["filter1", "particle_mass"].shape[0]

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def filter2(pfilter, data):
        filter_field = ("io", "particle_mass")
        return data[filter_field] > 0.5

    ds.add_particle_filter("filter2")
    ad = ds.all_data()
    # Only checks that evaluating the filtered field completes without error.
    ad["filter2", "particle_mass"].min()


def test_particle_filter_dependency():
    """
    Test dataset add_particle_filter which should automatically add
    the dependency of the filter.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def h_stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.5

    @particle_filter(filtered_type="h_stars", requires=["particle_mass"])
    def hh_stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.9

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("hh_stars")
    assert "hh_stars" in ds.particle_types
    assert "h_stars" in ds.particle_types
    assert ("deposit", "hh_stars_cic") in ds.derived_field_list
    assert ("deposit", "h_stars_cic") in ds.derived_field_list


def test_covering_grid_particle_filter():
    """Check covering grids and native grids select the same filtered particles.

    For every grid in the index, a covering grid spanning the same region is
    built and the number of "heavy_stars" particles it returns is compared
    with the count returned by the grid itself.
    """

    @particle_filter(filtered_type="all", requires=["particle_mass"])
    def heavy_stars(pfilter, data):
        filter_field = (pfilter.filtered_type, "particle_mass")
        return data[filter_field] > 0.5

    ds = fake_random_ds(16, nprocs=8, particles=16)
    ds.add_particle_filter("heavy_stars")

    for grid in ds.index.grids:
        cg = ds.covering_grid(grid.Level, grid.LeftEdge, grid.ActiveDimensions)

        assert_equal(
            cg["heavy_stars", "particle_mass"].shape[0],
            grid["heavy_stars", "particle_mass"].shape[0],
        )


def test_sph_particle_filter_plotting():
    """Smoke-test a projection plot of a field from a filtered SPH particle type.

    The "central_gas" filter keeps "io" particles whose |x|, |y| and |z|
    coordinates are all below 1.6. The plot is saved inside a temporary
    directory, which is removed afterwards.
    """
    ds = fake_sph_grid_ds()

    @particle_filter("central_gas", requires=["particle_position"], filtered_type="io")
    def _filter(pfilter, data):
        coords = np.abs(data[pfilter.filtered_type, "particle_position"])
        return (coords[:, 0] < 1.6) & (coords[:, 1] < 1.6) & (coords[:, 2] < 1.6)

    ds.add_particle_filter("central_gas")

    plot = ProjectionPlot(ds, "z", ("central_gas", "density"))
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()
    os.chdir(tmpdir)
    try:
        plot.save()
    finally:
        # Always restore the working directory and clean up, even if saving
        # fails, so later tests do not run inside a deleted/stale directory.
        os.chdir(curdir)
        shutil.rmtree(tmpdir)