1import os
2import shutil
3import tempfile
4import unittest
5from unittest import mock
6
7import numpy as np
8
9from yt.data_objects.particle_filters import add_particle_filter
10from yt.data_objects.profiles import create_profile
11from yt.loaders import load
12from yt.testing import (
13    assert_allclose,
14    assert_array_almost_equal,
15    fake_particle_ds,
16    requires_file,
17)
18from yt.units.yt_array import YTArray
19from yt.utilities.answer_testing.framework import (
20    PhasePlotAttributeTest,
21    PlotWindowAttributeTest,
22    data_dir_load,
23    requires_ds,
24)
25from yt.visualization.api import ParticlePhasePlot, ParticlePlot, ParticleProjectionPlot
26from yt.visualization.tests.test_plotwindow import ATTR_ARGS, WIDTH_SPECS
27
28
def setup():
    """Mark the yt configuration as running inside the test suite."""
    from yt.config import ytcfg

    # The configuration is addressed by a (section, subsection, option) key.
    testing_flag = ("yt", "internals", "within_testing")
    ytcfg[testing_flag] = True
34
35
# Start from the plotwindow ATTR_ARGS and override the entries below with
# arguments that name the ("all", "particle_mass") field.
PROJ_ATTR_ARGS = ATTR_ARGS.copy()
PROJ_ATTR_ARGS.update(
    {
        "set_cmap": [
            ((("all", "particle_mass"), "RdBu"), {}),
            ((("all", "particle_mass"), "kamae"), {}),
        ],
        "set_log": [((("all", "particle_mass"), False), {})],
        "set_zlim": [
            ((("all", "particle_mass"), 1e39, 1e42), {}),
            ((("all", "particle_mass"), 1e39, None), {"dynamic_range": 4}),
        ],
    }
)
47
# Maps a phase plot method name to the list of (args, kwargs) pairs that
# test_particle_phase_answers hands to PhasePlotAttributeTest.
PHASE_ATTR_ARGS = dict(
    annotate_text=[
        (((5e-29, 5e7), "Hello YT"), {}),
        (((5e-29, 5e7), "Hello YT"), {"color": "b"}),
    ],
    set_title=[((("all", "particle_mass"), "A phase plot."), {})],
    set_log=[((("all", "particle_mass"), False), {})],
    set_unit=[((("all", "particle_mass"), "Msun"), {})],
    set_xlim=[((-4e7, 4e7), {})],
    set_ylim=[((-4e7, 4e7), {})],
)
59
# File names passed to ``save()`` by the save tests in this module.
TEST_FLNMS = [
    None,
    "test",
    "test.png",
    "test.eps",
    "test.ps",
    "test.pdf",
]

# Values passed as ``center=`` to ParticleProjectionPlot.
CENTER_SPECS = (
    # string forms
    "c", "C", "center", "Center",
    # plain list of three numbers
    [0.5, 0.5, 0.5],
    # [values, unit] pair
    [[0.2, 0.3, 0.4], "cm"],
    # array carrying its own unit
    YTArray([0.3, 0.4, 0.7], "cm"),
)  # fmt: skip

# Values passed as ``weight_field=`` to ParticleProjectionPlot.
WEIGHT_FIELDS = (
    None,
    ("all", "particle_ones"),
    ("all", "particle_mass"),
)
73
# (x_field, y_field, z_fields) triples used to build phase plots.  The x and
# y entries are given by name and expanded to ("all", name); the z entry is
# either one field tuple or a list of field tuples.
PHASE_FIELDS = [
    (("all", x_name), ("all", y_name), z_spec)
    for x_name, y_name, z_spec in (
        ("particle_velocity_x", "particle_position_z", ("all", "particle_mass")),
        ("particle_position_x", "particle_position_y", ("all", "particle_ones")),
        (
            "particle_velocity_x",
            "particle_velocity_y",
            [("all", "particle_mass"), ("all", "particle_ones")],
        ),
    )
]
91
92
# Dataset path handed to requires_ds / data_dir_load by the answer tests below.
g30 = "IsolatedGalaxy/galaxy0030/galaxy0030"
94
95
@requires_ds(g30, big_data=True)
def test_particle_projection_answers():
    """

    This iterates over all the plot modification functions in
    PROJ_ATTR_ARGS, along each of the x, y and z axes. Each time, it
    compares the images produced by ParticleProjectionPlot to the gold
    standard.

    """

    ds = data_dir_load(g30)
    field = ("all", "particle_mass")
    n_decimals = 12
    for axis in "xyz":
        for attr_name, call_specs in PROJ_ATTR_ARGS.items():
            for call_args in call_specs:
                answer_test = PlotWindowAttributeTest(
                    ds,
                    field,
                    axis,
                    attr_name,
                    call_args,
                    n_decimals,
                    "ParticleProjectionPlot",
                )
                # Rename the generator after the test it is about to yield.
                test_particle_projection_answers.__name__ = answer_test.description
                yield answer_test
124
125
@requires_ds(g30, big_data=True)
def test_particle_projection_filter():
    """

    This tests particle projection plots for filter fields, using a
    "formed_star" particle filter registered on the dataset.

    """

    def formed_star(pfilter, data):
        # Select the particles whose creation time is positive.
        return data["all", "creation_time"] > 0

    add_particle_filter(
        "formed_star",
        function=formed_star,
        filtered_type="all",
        requires=["creation_time"],
    )

    ds = data_dir_load(g30)
    ds.add_particle_filter("formed_star")

    filtered_field = ("formed_star", "particle_mass")
    n_decimals = 12
    attr_name = "set_log"
    for axis in "xyz":
        for call_args in PROJ_ATTR_ARGS[attr_name]:
            answer_test = PlotWindowAttributeTest(
                ds,
                filtered_field,
                axis,
                attr_name,
                call_args,
                n_decimals,
                "ParticleProjectionPlot",
            )
            # Rename the generator after the test it is about to yield.
            test_particle_projection_filter.__name__ = answer_test.description
            yield answer_test
159
160
@requires_ds(g30, big_data=True)
def test_particle_phase_answers():
    """

    This iterates over all the plot modification functions in
    PHASE_ATTR_ARGS. Each time, it compares the images produced by
    ParticlePhasePlot to the gold standard.

    """

    ds = data_dir_load(g30)
    n_decimals = 12
    phase_fields = (
        ("all", "particle_velocity_x"),
        ("all", "particle_velocity_y"),
        ("all", "particle_mass"),
    )

    for attr_name, call_specs in PHASE_ATTR_ARGS.items():
        for call_args in call_specs:
            answer_test = PhasePlotAttributeTest(
                ds,
                *phase_fields,
                attr_name,
                call_args,
                n_decimals,
                "ParticlePhasePlot",
            )
            # Rename the generator after the test it is about to yield.
            test_particle_phase_answers.__name__ = answer_test.description
            yield answer_test
192
193
class TestParticlePhasePlotSave(unittest.TestCase):
    """Build particle phase plots from a fake dataset and save each of them."""

    def setUp(self):
        # Work from a fresh scratch directory, remembering where to go back to.
        self.tmpdir = tempfile.mkdtemp()
        self.curdir = os.getcwd()
        os.chdir(self.tmpdir)

    def tearDown(self):
        # Step out of the scratch directory before deleting it.
        os.chdir(self.curdir)
        shutil.rmtree(self.tmpdir)

    def test_particle_phase_plot(self):
        ds = fake_particle_ds()
        sources = (
            ds.region([0.5] * 3, [0.4] * 3, [0.6] * 3),
            ds.all_data(),
        )
        binning = {"x_bins": 16, "y_bins": 16}
        plots = []

        # For every source / field combination build three plots: one with
        # the default settings, one with deposition="cic", and one from a
        # precomputed profile.
        for source in sources:
            for x_field, y_field, z_fields in PHASE_FIELDS:
                plot_args = (source, x_field, y_field, z_fields)
                plots.append(ParticlePhasePlot(*plot_args, **binning))
                plots.append(
                    ParticlePhasePlot(*plot_args, **binning, deposition="cic")
                )

                profile = create_profile(
                    source,
                    [x_field, y_field],
                    z_fields,
                    weight_field=("all", "particle_ones"),
                    n_bins=[16, 16],
                )
                plots.append(ParticlePhasePlot.from_profile(profile))

        plots[0]._repr_html_()

        # print_figure is replaced by a mock on all three canvases while the
        # plots are saved under every file name in TEST_FLNMS.
        agg_patch = mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
        )
        pdf_patch = mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPdf.print_figure"
        )
        ps_patch = mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPS.print_figure"
        )
        with agg_patch, pdf_patch, ps_patch:
            for plot in plots:
                for fname in TEST_FLNMS:
                    plot.save(fname)
258
259
# Dataset path handed to requires_file / load by the tests below.
tgal = "TipsyGalaxy/galaxy.00300"
261
262
@requires_file(tgal)
def test_particle_phase_plot_semantics():
    """Check the profile bin edges of a phase plot with log and linear axes."""
    x_field = ("Gas", "density")
    y_field = ("Gas", "temperature")

    ds = load(tgal)
    ad = ds.all_data()
    dens_ex = ad.quantities.extrema(x_field)
    temp_ex = ad.quantities.extrema(y_field)
    plot = ParticlePlot(ds, x_field, y_field, ("Gas", "particle_mass"))

    def check_bin_extrema(profile):
        # The outermost bin edges are the field extrema, widened by one
        # floating point spacing on each side.
        assert dens_ex[0] - np.spacing(dens_ex[0]) == profile.x_bins[0]
        assert dens_ex[-1] + np.spacing(dens_ex[-1]) == profile.x_bins[-1]
        assert temp_ex[0] - np.spacing(temp_ex[0]) == profile.y_bins[0]
        assert temp_ex[-1] + np.spacing(temp_ex[-1]) == profile.y_bins[-1]

    def check_even_spacing(edges):
        # Every bin width must match the first one.
        widths = edges[1:] - edges[:-1]
        assert_allclose(widths, widths[0])

    # Logarithmic axes: bins are evenly spaced in log space.
    plot.set_log(x_field, True)
    plot.set_log(y_field, True)
    log_profile = plot.profile
    check_bin_extrema(log_profile)
    check_even_spacing(np.log10(log_profile.x_bins))
    check_even_spacing(np.log10(log_profile.y_bins))

    # Linear axes: bins are evenly spaced in linear space.
    plot.set_log(x_field, False)
    plot.set_log(y_field, False)
    linear_profile = plot.profile
    check_bin_extrema(linear_profile)
    check_even_spacing(linear_profile.x_bins)
    check_even_spacing(linear_profile.y_bins)
307
308
@requires_file(tgal)
def test_set_units():
    """Setting a unit with a field tuple on a phase plot must not raise."""
    mass_field = ("Gas", "particle_mass")

    ds = load(tgal)
    sphere = ds.sphere("max", (1.0, "Mpc"))
    phase_plot = ParticlePhasePlot(
        sphere, ("Gas", "density"), ("Gas", "temperature"), mass_field
    )
    # There is nothing to assert: the call succeeding is the test.
    phase_plot.set_unit(mass_field, "Msun")
318
319
@requires_file(tgal)
def test_switch_ds():
    """
    Tests the _switch_ds() method for ParticleProjectionPlots that as of
    25th October 2017 requires a specific hack in plot_container.py
    """
    # Two separately loaded copies of the same dataset.
    first_ds = load(tgal)
    second_ds = load(tgal)

    plot_fields = (
        ("Gas", "particle_position_x"),
        ("Gas", "particle_position_y"),
        ("Gas", "density"),
    )
    plot = ParticlePlot(first_ds, *plot_fields)

    # The switch completing without an exception is the whole check.
    plot._switch_ds(second_ds)
339
340
class TestParticleProjectionPlotSave(unittest.TestCase):
    """Create particle projection plots from a fake dataset and save them."""

    def setUp(self):
        # Work from a fresh scratch directory, remembering where to go back to.
        self.tmpdir = tempfile.mkdtemp()
        self.curdir = os.getcwd()
        os.chdir(self.tmpdir)

    def tearDown(self):
        # Step out of the scratch directory before deleting it.
        os.chdir(self.curdir)
        shutil.rmtree(self.tmpdir)

    def test_particle_plot(self):
        ds = fake_particle_ds()
        mass = ("all", "particle_mass")
        plots = []
        # Three variants per axis: defaults, deposition="cic", density=True.
        for axis in range(3):
            plots.append(ParticleProjectionPlot(ds, axis, mass))
            plots.append(ParticleProjectionPlot(ds, axis, mass, deposition="cic"))
            plots.append(ParticleProjectionPlot(ds, axis, mass, density=True))
        plots[0]._repr_html_()

        # print_figure is replaced by a mock on all three canvases while the
        # plots are saved under every file name in TEST_FLNMS.
        agg_patch = mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
        )
        pdf_patch = mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPdf.print_figure"
        )
        ps_patch = mock.patch(
            "yt.visualization._mpl_imports.FigureCanvasPS.print_figure"
        )
        with agg_patch, pdf_patch, ps_patch:
            for plot in plots:
                for fname in TEST_FLNMS:
                    # Indexing the result means save() has to hand back a
                    # non-empty sequence.
                    plot.save(fname)[0]

    def test_particle_plot_ds(self):
        ds = fake_particle_ds()
        region = ds.region([0.5] * 3, [0.4] * 3, [0.6] * 3)
        for axis in range(3):
            plot = ParticleProjectionPlot(
                ds, axis, ("all", "particle_mass"), data_source=region
            )
            agg_patch = mock.patch(
                "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
            )
            with agg_patch:
                plot.save()

    def test_particle_plot_c(self):
        ds = fake_particle_ds()
        for center in CENTER_SPECS:
            for axis in range(3):
                plot = ParticleProjectionPlot(
                    ds, axis, ("all", "particle_mass"), center=center
                )
                agg_patch = mock.patch(
                    "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
                )
                with agg_patch:
                    plot.save()

    def test_particle_plot_wf(self):
        ds = fake_particle_ds()
        for axis in range(3):
            for weight_field in WEIGHT_FIELDS:
                plot = ParticleProjectionPlot(
                    ds, axis, ("all", "particle_mass"), weight_field=weight_field
                )
                agg_patch = mock.patch(
                    "yt.visualization._mpl_imports.FigureCanvasAgg.print_figure"
                )
                with agg_patch:
                    plot.save()

    def test_creation_with_width(self):
        ds = fake_particle_ds()
        for width, (xlim, ylim, pwidth, _aun) in WIDTH_SPECS.items():
            plot = ParticleProjectionPlot(ds, 0, ("all", "particle_mass"), width=width)

            # Turn each expected (value, unit) entry into a dataset quantity.
            expected_xlim = [plot.ds.quan(el[0], el[1]) for el in xlim]
            expected_ylim = [plot.ds.quan(el[0], el[1]) for el in ylim]
            expected_width = [plot.ds.quan(el[0], el[1]) for el in pwidth]

            # Compare element by element, to 14 decimal places.
            for actual, expected in zip(plot.xlim, expected_xlim):
                assert_array_almost_equal(actual, expected, 14)
            for actual, expected in zip(plot.ylim, expected_ylim):
                assert_array_almost_equal(actual, expected, 14)
            for actual, expected in zip(plot.width, expected_width):
                assert_array_almost_equal(actual, expected, 14)
426
427
def test_particle_plot_instance():
    """
    Tests the type of plot instance returned by ParticlePlot.

    If x_field and y_field are both particle position fields, a
    ParticleProjectionPlot instance is expected, whichever order they are
    given in. With a velocity field as y_field, a ParticlePhasePlot
    instance is expected.

    """
    ds = fake_particle_ds()
    pos_x = ("all", "particle_position_x")
    pos_y = ("all", "particle_position_y")
    vel_x = ("all", "particle_velocity_x")

    # (x_field, y_field, expected plot class)
    cases = (
        (pos_x, pos_y, ParticleProjectionPlot),
        (pos_y, pos_x, ParticleProjectionPlot),
        (pos_x, vel_x, ParticlePhasePlot),
    )
    for x_field, y_field, expected_cls in cases:
        plot = ParticlePlot(ds, x_field, y_field)
        assert isinstance(plot, expected_cls)
450