"""Tests for yt particle plots: ParticleProjectionPlot, ParticlePhasePlot and
the ParticlePlot factory."""

import os
import shutil
import tempfile
import unittest
from unittest import mock

import numpy as np

from yt.data_objects.particle_filters import add_particle_filter
from yt.data_objects.profiles import create_profile
from yt.loaders import load
from yt.testing import (
    assert_allclose,
    assert_array_almost_equal,
    fake_particle_ds,
    requires_file,
)
from yt.units.yt_array import YTArray
from yt.utilities.answer_testing.framework import (
    PhasePlotAttributeTest,
    PlotWindowAttributeTest,
    data_dir_load,
    requires_ds,
)
from yt.visualization.api import ParticlePhasePlot, ParticlePlot, ParticleProjectionPlot
from yt.visualization.tests.test_plotwindow import ATTR_ARGS, WIDTH_SPECS


def setup():
    """Test specific setup."""
    from yt.config import ytcfg

    ytcfg["yt", "internals", "within_testing"] = True


# override some of the plotwindow ATTR_ARGS
PROJ_ATTR_ARGS = ATTR_ARGS.copy()
PROJ_ATTR_ARGS["set_cmap"] = [
    ((("all", "particle_mass"), "RdBu"), {}),
    ((("all", "particle_mass"), "kamae"), {}),
]
PROJ_ATTR_ARGS["set_log"] = [((("all", "particle_mass"), False), {})]
PROJ_ATTR_ARGS["set_zlim"] = [
    ((("all", "particle_mass"), 1e39, 1e42), {}),
    ((("all", "particle_mass"), 1e39, None), {"dynamic_range": 4}),
]

PHASE_ATTR_ARGS = {
    "annotate_text": [
        (((5e-29, 5e7), "Hello YT"), {}),
        (((5e-29, 5e7), "Hello YT"), {"color": "b"}),
    ],
    "set_title": [((("all", "particle_mass"), "A phase plot."), {})],
    "set_log": [((("all", "particle_mass"), False), {})],
    "set_unit": [((("all", "particle_mass"), "Msun"), {})],
    "set_xlim": [((-4e7, 4e7), {})],
    "set_ylim": [((-4e7, 4e7), {})],
}

TEST_FLNMS = [None, "test", "test.png", "test.eps", "test.ps", "test.pdf"]

CENTER_SPECS = (
    "c",
    "C",
    "center",
    "Center",
    [0.5, 0.5, 0.5],
    [[0.2, 0.3, 0.4], "cm"],
    YTArray([0.3, 0.4, 0.7], "cm"),
)

WEIGHT_FIELDS = (None, ("all", "particle_ones"), ("all", "particle_mass"))

PHASE_FIELDS = [
    (
        ("all", "particle_velocity_x"),
        ("all", "particle_position_z"),
        ("all", "particle_mass"),
    ),
    (
        ("all", "particle_position_x"),
        ("all", "particle_position_y"),
        ("all", "particle_ones"),
    ),
    (
        ("all", "particle_velocity_x"),
        ("all", "particle_velocity_y"),
        [("all", "particle_mass"), ("all", "particle_ones")],
    ),
]


g30 = "IsolatedGalaxy/galaxy0030/galaxy0030"


@requires_ds(g30, big_data=True)
def test_particle_projection_answers():
    """
    This iterates over all the plot modification functions in
    PROJ_ATTR_ARGS. Each time, it compares the images produced by
    ParticleProjectionPlot to the gold standard.
    """

    plot_field = ("all", "particle_mass")
    decimals = 12
    ds = data_dir_load(g30)
    for ax in "xyz":
        for attr_name in PROJ_ATTR_ARGS.keys():
            for args in PROJ_ATTR_ARGS[attr_name]:
                test = PlotWindowAttributeTest(
                    ds,
                    plot_field,
                    ax,
                    attr_name,
                    args,
                    decimals,
                    "ParticleProjectionPlot",
                )
                test_particle_projection_answers.__name__ = test.description
                yield test


@requires_ds(g30, big_data=True)
def test_particle_projection_filter():
    """
    This tests particle projection plots for filter fields.
    """

    def formed_star(pfilter, data):
        # Keep only particles with a positive creation time.
        return data["all", "creation_time"] > 0

    add_particle_filter(
        "formed_star",
        function=formed_star,
        filtered_type="all",
        requires=["creation_time"],
    )

    plot_field = ("formed_star", "particle_mass")

    decimals = 12
    ds = data_dir_load(g30)
    ds.add_particle_filter("formed_star")
    for ax in "xyz":
        attr_name = "set_log"
        for args in PROJ_ATTR_ARGS[attr_name]:
            test = PlotWindowAttributeTest(
                ds, plot_field, ax, attr_name, args, decimals, "ParticleProjectionPlot"
            )
            test_particle_projection_filter.__name__ = test.description
            yield test


@requires_ds(g30, big_data=True)
def test_particle_phase_answers():
    """
    This iterates over all the plot modification functions in
    PHASE_ATTR_ARGS. Each time, it compares the images produced by
    ParticlePhasePlot to the gold standard.
    """

    decimals = 12
    ds = data_dir_load(g30)

    x_field = ("all", "particle_velocity_x")
    y_field = ("all", "particle_velocity_y")
    z_field = ("all", "particle_mass")
    for attr_name in PHASE_ATTR_ARGS.keys():
        for args in PHASE_ATTR_ARGS[attr_name]:
            test = PhasePlotAttributeTest(
                ds,
                x_field,
                y_field,
                z_field,
                attr_name,
                args,
                decimals,
                "ParticlePhasePlot",
            )

            test_particle_phase_answers.__name__ = test.description
            yield test


class TestParticlePhasePlotSave(unittest.TestCase):
    # Each test runs inside a fresh temporary directory that is removed
    # afterwards, so any files produced by save() do not leak into the cwd.
    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        self.curdir = os.getcwd()
        os.chdir(self.tmpdir)

    def tearDown(self):
        os.chdir(self.curdir)
        shutil.rmtree(self.tmpdir)

    def test_particle_phase_plot(self):
        """Build phase plots over several sources/fields and save each one
        under every name in TEST_FLNMS."""
        test_ds = fake_particle_ds()
        data_sources = [
            test_ds.region([0.5] * 3, [0.4] * 3, [0.6] * 3),
            test_ds.all_data(),
        ]
        particle_phases = []

        for source in data_sources:
            for x_field, y_field, z_fields in PHASE_FIELDS:
                particle_phases.append(
                    ParticlePhasePlot(
                        source,
                        x_field,
                        y_field,
                        z_fields,
                        x_bins=16,
                        y_bins=16,
                    )
                )

                particle_phases.append(
                    ParticlePhasePlot(
                        source,
                        x_field,
                        y_field,
                        z_fields,
                        x_bins=16,
                        y_bins=16,
                        deposition="cic",
                    )
                )

                pp = create_profile(
                    source,
                    [x_field, y_field],
                    z_fields,
                    weight_field=("all", "particle_ones"),
                    n_bins=[16, 16],
                )

                particle_phases.append(ParticlePhasePlot.from_profile(pp))
        particle_phases[0]._repr_html_()

        # The backend print_figure methods are replaced by mocks, so save()
        # runs without the canvases actually rendering the figures.
        with mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
        ), mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPdf.print_figure"
        ), mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPS.print_figure"
        ):
            for p in particle_phases:
                for fname in TEST_FLNMS:
                    p.save(fname)


tgal = "TipsyGalaxy/galaxy.00300"


@requires_file(tgal)
def test_particle_phase_plot_semantics():
    """Check that phase-plot bin edges track the field extrema and are evenly
    spaced in log space (log scaling) or linear space (linear scaling)."""
    ds = load(tgal)
    ad = ds.all_data()
    dens_ex = ad.quantities.extrema(("Gas", "density"))
    temp_ex = ad.quantities.extrema(("Gas", "temperature"))
    plot = ParticlePlot(
        ds, ("Gas", "density"), ("Gas", "temperature"), ("Gas", "particle_mass")
    )
    plot.set_log(("Gas", "density"), True)
    plot.set_log(("Gas", "temperature"), True)
    p = plot.profile

    # bin extrema are field extrema, widened outward by one ULP (np.spacing)
    assert dens_ex[0] - np.spacing(dens_ex[0]) == p.x_bins[0]
    assert dens_ex[-1] + np.spacing(dens_ex[-1]) == p.x_bins[-1]
    assert temp_ex[0] - np.spacing(temp_ex[0]) == p.y_bins[0]
    assert temp_ex[-1] + np.spacing(temp_ex[-1]) == p.y_bins[-1]

    # bins are evenly spaced in log space
    logxbins = np.log10(p.x_bins)
    dxlogxbins = logxbins[1:] - logxbins[:-1]
    assert_allclose(dxlogxbins, dxlogxbins[0])

    logybins = np.log10(p.y_bins)
    dylogybins = logybins[1:] - logybins[:-1]
    assert_allclose(dylogybins, dylogybins[0])

    plot.set_log(("Gas", "density"), False)
    plot.set_log(("Gas", "temperature"), False)
    p = plot.profile

    # bin extrema are field extrema, widened outward by one ULP (np.spacing)
    assert dens_ex[0] - np.spacing(dens_ex[0]) == p.x_bins[0]
    assert dens_ex[-1] + np.spacing(dens_ex[-1]) == p.x_bins[-1]
    assert temp_ex[0] - np.spacing(temp_ex[0]) == p.y_bins[0]
    assert temp_ex[-1] + np.spacing(temp_ex[-1]) == p.y_bins[-1]

    # bins are evenly spaced in linear space
    dxbins = p.x_bins[1:] - p.x_bins[:-1]
    assert_allclose(dxbins, dxbins[0])

    dybins = p.y_bins[1:] - p.y_bins[:-1]
    assert_allclose(dybins, dybins[0])


@requires_file(tgal)
def test_set_units():
    ds = load(tgal)
    sp = ds.sphere("max", (1.0, "Mpc"))
    pp = ParticlePhasePlot(
        sp, ("Gas", "density"), ("Gas", "temperature"), ("Gas", "particle_mass")
    )
    # make sure we can set the units using the tuple without erroring out
    pp.set_unit(("Gas", "particle_mass"), "Msun")


@requires_file(tgal)
def test_switch_ds():
    """
    Tests the _switch_ds() method for ParticleProjectionPlots that as of
    25th October 2017 requires a specific hack in plot_container.py
    """
    ds = load(tgal)
    ds2 = load(tgal)

    plot = ParticlePlot(
        ds,
        ("Gas", "particle_position_x"),
        ("Gas", "particle_position_y"),
        ("Gas", "density"),
    )

    # The test passes as long as switching datasets does not raise.
    plot._switch_ds(ds2)


class TestParticleProjectionPlotSave(unittest.TestCase):
    # Each test runs inside a fresh temporary directory that is removed
    # afterwards, so any files produced by save() do not leak into the cwd.
    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        self.curdir = os.getcwd()
        os.chdir(self.tmpdir)

    def tearDown(self):
        os.chdir(self.curdir)
        shutil.rmtree(self.tmpdir)

    def test_particle_plot(self):
        """Save plain, CIC-deposited and density projections along every axis
        under every name in TEST_FLNMS."""
        test_ds = fake_particle_ds()
        particle_projs = []
        for dim in range(3):
            particle_projs += [
                ParticleProjectionPlot(test_ds, dim, ("all", "particle_mass")),
                ParticleProjectionPlot(
                    test_ds, dim, ("all", "particle_mass"), deposition="cic"
                ),
                ParticleProjectionPlot(
                    test_ds, dim, ("all", "particle_mass"), density=True
                ),
            ]
        particle_projs[0]._repr_html_()
        # The backend print_figure methods are replaced by mocks, so save()
        # runs without the canvases actually rendering the figures.
        with mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
        ), mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPdf.print_figure"
        ), mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPS.print_figure"
        ):
            for p in particle_projs:
                for fname in TEST_FLNMS:
                    p.save(fname)

    def test_particle_plot_ds(self):
        test_ds = fake_particle_ds()
        ds_region = test_ds.region([0.5] * 3, [0.4] * 3, [0.6] * 3)
        for dim in range(3):
            pplot_ds = ParticleProjectionPlot(
                test_ds, dim, ("all", "particle_mass"), data_source=ds_region
            )
            with mock.patch(
                "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
            ):
                pplot_ds.save()

    def test_particle_plot_c(self):
        test_ds = fake_particle_ds()
        for center in CENTER_SPECS:
            for dim in range(3):
                pplot_c = ParticleProjectionPlot(
                    test_ds, dim, ("all", "particle_mass"), center=center
                )
                with mock.patch(
                    "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
                ):
                    pplot_c.save()

    def test_particle_plot_wf(self):
        test_ds = fake_particle_ds()
        for dim in range(3):
            for weight_field in WEIGHT_FIELDS:
                pplot_wf = ParticleProjectionPlot(
                    test_ds, dim, ("all", "particle_mass"), weight_field=weight_field
                )
                with mock.patch(
                    "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
                ):
                    pplot_wf.save()

    def test_creation_with_width(self):
        """Check plot limits and width for every entry in WIDTH_SPECS."""
        test_ds = fake_particle_ds()
        for width, (xlim, ylim, pwidth, _aun) in WIDTH_SPECS.items():
            plot = ParticleProjectionPlot(
                test_ds, 0, ("all", "particle_mass"), width=width
            )

            expected_xlim = [plot.ds.quan(el[0], el[1]) for el in xlim]
            expected_ylim = [plot.ds.quan(el[0], el[1]) for el in ylim]
            expected_width = [plot.ds.quan(el[0], el[1]) for el in pwidth]

            # Plain loops: the assertions are called for their side effect.
            for px, x in zip(plot.xlim, expected_xlim):
                assert_array_almost_equal(px, x, 14)
            for py, y in zip(plot.ylim, expected_ylim):
                assert_array_almost_equal(py, y, 14)
            for pw, w in zip(plot.width, expected_width):
                assert_array_almost_equal(pw, w, 14)


def test_particle_plot_instance():
    """
    Tests the type of plot instance returned by ParticlePlot.

    If x_field and y_field are any combination of valid particle_position
    fields along the x, y or z axis, then a ParticleProjectionPlot instance
    is expected; otherwise (e.g. a velocity field on one axis) a
    ParticlePhasePlot instance is expected.
    """
    ds = fake_particle_ds()
    x_field = ("all", "particle_position_x")
    y_field = ("all", "particle_position_y")
    z_field = ("all", "particle_velocity_x")

    plot = ParticlePlot(ds, x_field, y_field)
    assert isinstance(plot, ParticleProjectionPlot)

    plot = ParticlePlot(ds, y_field, x_field)
    assert isinstance(plot, ParticleProjectionPlot)

    plot = ParticlePlot(ds, x_field, z_field)
    assert isinstance(plot, ParticlePhasePlot)