1import base64
2import builtins
3import os
4from collections import OrderedDict
5from functools import wraps
6
7import matplotlib
8import numpy as np
9from more_itertools.more import always_iterable, unzip
10from packaging.version import parse as parse_version
11
12from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys
13from yt.data_objects.static_output import Dataset
14from yt.frontends.ytdata.data_structures import YTProfileDataset
15from yt.funcs import is_sequence, iter_fields, matplotlib_style_context
16from yt.utilities.exceptions import YTNotInsideNotebook
17from yt.utilities.logger import ytLogger as mylog
18
19from ..data_objects.selection_objects.data_selection_objects import YTSelectionContainer
20from ._commons import validate_image_name
21from .base_plot_types import ImagePlotMPL, PlotMPL
22from .plot_container import (
23    ImagePlotContainer,
24    get_log_minorticks,
25    invalidate_plot,
26    linear_transform,
27    log_transform,
28    validate_plot,
29)
30
# Installed matplotlib version parsed into a comparable Version object.
# NOTE(review): not referenced in the visible part of this module; presumably
# used elsewhere to gate version-dependent behavior -- verify before removing.
MPL_VERSION = parse_version(matplotlib.__version__)
32
33
def invalidate_profile(f):
    """Decorate a method so that calling it marks the profile as stale.

    The wrapped callable runs first. Once it returns normally, the object
    passed as the first positional argument gets ``_profile_valid = False``.
    The wrapped callable's return value is then handed back unchanged.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        result = f(*args, **kwargs)
        owner = args[0]
        owner._profile_valid = False
        return result

    return wrapper
42
43
class PlotContainerDict(OrderedDict):
    """Ordered mapping that lazily creates a ``PlotMPL`` for unknown keys.

    Looking up a key that is not present builds a new ``PlotMPL`` with
    figure size ``(10, 8)`` and axes rectangle ``[0.1, 0.1, 0.8, 0.8]``.
    The new plot is stored under that key and then returned.
    """

    def __missing__(self, key):
        new_plot = PlotMPL((10, 8), [0.1, 0.1, 0.8, 0.8], None, None)
        self[key] = new_plot
        return new_plot
49
50
class FigureContainer(OrderedDict):
    """Ordered, lazily filled view of the figures owned by a plot mapping.

    Parameters
    ----------
    plots : mapping
        Maps a key to a plot object exposing a ``figure`` attribute.

    A missing key is resolved on first access by caching
    ``plots[key].figure`` in this container. Iteration yields the keys of
    ``plots``, not only the keys cached so far.
    """

    def __init__(self, plots):
        self.plots = plots
        super().__init__()

    def __missing__(self, key):
        figure = self.plots[key].figure
        self[key] = figure
        return figure

    def __iter__(self):
        source = self.plots
        return iter(source)
62
63
class AxesContainer(OrderedDict):
    """Ordered, lazily filled view of plot axes plus stored axis limits.

    Parameters
    ----------
    plots : mapping
        Maps a key to a plot object exposing an ``axes`` attribute.

    Attributes
    ----------
    ylim : dict
        Per-key ``(ymin, ymax)`` tuple. It is reset to ``(None, None)``
        whenever an entry is stored in this container.
    xlim : tuple
        A single ``(xmin, xmax)`` tuple (not per key), initially
        ``(None, None)``.
    """

    def __init__(self, plots):
        self.plots = plots
        self.ylim, self.xlim = {}, (None, None)
        super().__init__()

    def __missing__(self, key):
        axes = self.plots[key].axes
        self[key] = axes
        return axes

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        # Any (re)assignment of an axes starts from automatic y limits.
        self.ylim[key] = (None, None)
78
79
def sanitize_label(labels, nprofiles):
    """Normalize plot labels into a list with one entry per profile.

    Parameters
    ----------
    labels : None, str, or iterable of (str or None)
        An empty input (e.g. ``None``) becomes ``[None]``. A single label
        is repeated for every profile.
    nprofiles : int
        Number of profiles the labels must cover.

    Returns
    -------
    list
        ``nprofiles`` entries, each a string or ``None``.

    Raises
    ------
    ValueError
        If the number of labels is neither 1 nor ``nprofiles``.
    TypeError
        If any label is neither ``None`` nor a string.
    """
    labels = list(always_iterable(labels)) or [None]

    if len(labels) == 1:
        labels = labels * nprofiles

    if len(labels) != nprofiles:
        raise ValueError(
            f"Number of labels {len(labels)} must match number of profiles {nprofiles}"
        )

    # Collect offending labels into concrete lists so that the error message
    # shows their actual values and types. Lazy iterators would be rendered
    # as opaque object reprs.
    invalid_labels = [
        label for label in labels if label is not None and not isinstance(label, str)
    ]
    if invalid_labels:
        types = [type(label) for label in invalid_labels]
        raise TypeError(
            "All labels must be None or a string, "
            f"received {invalid_labels} with type {types}"
        )

    return labels
104
105
def data_object_or_all_data(data_source):
    """Return a yt selection container for *data_source*.

    A ``Dataset`` is replaced by its ``all_data()`` container. The result
    must be a ``YTSelectionContainer``.

    Raises
    ------
    RuntimeError
        If the resulting object is not a ``YTSelectionContainer``.
    """
    container = (
        data_source.all_data() if isinstance(data_source, Dataset) else data_source
    )
    if isinstance(container, YTSelectionContainer):
        return container
    raise RuntimeError("data_source must be a yt selection data object")
114
115
class ProfilePlot:
    r"""
    Create a 1d profile plot from a data source or from a list
    of profile objects.

    Given a data object (all_data, region, sphere, etc.), an x field,
    and a y field (or fields), this will create a one-dimensional profile
    of the average (or total) value of the y field in bins of the x field.

    This can be used to create profiles from given fields or to plot
    multiple profiles created from
    `yt.data_objects.profiles.create_profile`.

    Parameters
    ----------
    data_source : YTSelectionContainer Object
        The data object to be profiled, such as all_data, region, or
        sphere. If a dataset is passed in instead, an all_data data object
        is generated internally from the dataset.
    x_field : str
        The binning field for the profile.
    y_fields : str or list
        The field or fields to be profiled.
    weight_field : str
        The weight field for calculating weighted averages. If None,
        the profile values are the sum of the field values within the bin.
        Otherwise, the values are a weighted average.
        Default : ("gas", "mass")
    n_bins : int
        The number of bins in the profile.
        Default: 64.
    accumulation : bool
        If True, the profile values for a bin N are the cumulative sum of
        all the values from bin 0 to N.
        Default: False.
    fractional : bool
        If True the profile values are divided by the sum of all
        the profile data such that the profile represents a probability
        distribution function.
        Default: False.
    label : str or list of strings
        If a string, the label to be put on the line plotted.  If a list,
        this should be a list of labels for each profile to be overplotted.
        Labels must be strings (or None).
        Default: None.
    plot_spec : dict or list of dicts
        A dictionary or list of dictionaries containing plot keyword
        arguments.  For example, dict(color="red", linestyle=":").
        Default: None.
    x_log : bool
        Whether the x_axis should be plotted with a logarithmic
        scaling (True), or linear scaling (False).
        Default: True.
    y_log : dict or bool
        A dictionary containing field:boolean pairs, setting the logarithmic
        property for that field. May be overridden after instantiation using
        set_log.
        A single boolean can be passed to signify all fields should use
        logarithmic (True) or linear scaling (False).
        Default: True.

    Examples
    --------

    This creates profiles of a single dataset.

    >>> import yt
    >>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046")
    >>> ad = ds.all_data()
    >>> plot = yt.ProfilePlot(
    ...     ad,
    ...     ("gas", "density"),
    ...     [("gas", "temperature"), ("gas", "velocity_x")],
    ...     weight_field=("gas", "mass"),
    ...     plot_spec=dict(color="red", linestyle="--"),
    ... )
    >>> plot.save()

    This creates profiles from a time series object.

    >>> from yt.data_objects.profiles import create_profile
    >>> es = yt.load_simulation("AMRCosmology.enzo", "Enzo")
    >>> es.get_time_series()

    >>> profiles = []
    >>> labels = []
    >>> plot_specs = []
    >>> for ds in es[-4:]:
    ...     ad = ds.all_data()
    ...     profiles.append(
    ...         create_profile(
    ...             ad,
    ...             [("gas", "density")],
    ...             fields=[("gas", "temperature"), ("gas", "velocity_x")],
    ...         )
    ...     )
    ...     labels.append(f"z = {ds.current_redshift:.2f}")
    ...     plot_specs.append(dict(linestyle="--", alpha=0.7))

    >>> plot = yt.ProfilePlot.from_profiles(
    ...     profiles, labels=labels, plot_specs=plot_specs
    ... )
    >>> plot.save()

    Use set_line_property to change line properties of one or all profiles.

    """

    x_log = None
    y_log = None
    x_title = None
    y_title = None
    _plot_valid = False

    def __init__(
        self,
        data_source,
        x_field,
        y_fields,
        weight_field=("gas", "mass"),
        n_bins=64,
        accumulation=False,
        fractional=False,
        label=None,
        plot_spec=None,
        x_log=True,
        y_log=True,
    ):

        data_source = data_object_or_all_data(data_source)
        y_fields = list(iter_fields(y_fields))
        logs = {x_field: bool(x_log)}
        if isinstance(y_log, bool):
            # Broadcast a single flag to every profiled field.
            y_log = {y_field: y_log for y_field in y_fields}

        if isinstance(data_source.ds, YTProfileDataset):
            # The dataset already carries a profile; use it rather than
            # calling create_profile.
            profiles = [data_source.ds.profile]
        else:
            profiles = [
                create_profile(
                    data_source,
                    [x_field],
                    n_bins=[n_bins],
                    fields=y_fields,
                    weight_field=weight_field,
                    accumulation=accumulation,
                    fractional=fractional,
                    logs=logs,
                )
            ]

        if plot_spec is None:
            plot_spec = [{} for _ in profiles]
        if not isinstance(plot_spec, list):
            # Copy a single spec per profile so that set_line_property, which
            # mutates specs in place, never touches the caller's dict.
            plot_spec = [plot_spec.copy() for _ in profiles]

        ProfilePlot._initialize_instance(self, profiles, label, plot_spec, y_log)

    @validate_plot
    def save(self, name=None, suffix=".png", mpl_kwargs=None):
        r"""
        Saves a 1d profile plot.

        Parameters
        ----------
        name : str
            The output file keyword.
        suffix : string
            Specify the image type by its suffix. If not specified, the output
            type will be inferred from the filename. Defaults to PNG.
        mpl_kwargs : dict
            A dict of keyword arguments to be passed to matplotlib.

        Returns
        -------
        list of str
            The file names that were written, one per distinct plot.
        """
        if not self._plot_valid:
            self._setup_plots()
        unique = self._unique_plots()
        if len(unique) < len(self.plots):
            # Several fields share one plot object: save each distinct plot
            # once, identified by its position.
            iters = enumerate(unique)
        else:
            iters = self.plots.items()

        if name is None:
            if len(self.profiles) == 1:
                name = str(self.profiles[0].ds)
            else:
                name = "Multi-data"

        name = validate_image_name(name, suffix)
        prefix, suffix = os.path.splitext(name)

        xfn = self.profiles[0].x_field
        if isinstance(xfn, tuple):
            xfn = xfn[1]

        names = []
        for uid, plot in iters:
            if isinstance(uid, tuple):
                uid = uid[1]
            uid_name = f"{prefix}_1d-Profile_{xfn}_{uid}{suffix}"
            names.append(uid_name)
            mylog.info("Saving %s", uid_name)
            with matplotlib_style_context():
                plot.save(uid_name, mpl_kwargs=mpl_kwargs)
        return names

    @validate_plot
    def show(self):
        r"""This will send any existing plots to the IPython notebook.

        If yt is being run from within an IPython session, and it is able to
        determine this, this function will send any existing plots to the
        notebook for display.

        If yt can't determine if it's inside an IPython session, it will raise
        YTNotInsideNotebook.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> pp = yt.ProfilePlot(ds.all_data(), ("gas", "density"), ("gas", "temperature"))
        >>> pp.show()

        """
        if "__IPYTHON__" in dir(builtins):
            from IPython.display import display

            display(self)
        else:
            raise YTNotInsideNotebook

    @validate_plot
    def _repr_html_(self):
        """Return an html representation of the plot object. Will display as a
        png for each distinct PlotMPL instance in self.plots"""
        ret = ""
        for plot in self._unique_plots():
            with matplotlib_style_context():
                img = plot._repr_png_()
            img = base64.b64encode(img).decode()
            ret += (
                r'<img style="max-width:100%;max-height:100%;" '
                r'src="data:image/png;base64,{}"><br>'.format(img)
            )
        return ret

    def _unique_plots(self):
        """Return the distinct plot objects in ``self.plots``, first-seen order.

        Plot objects define no ordering, so sorting them is not an option.
        Deduplicating in insertion order is deterministic and needs only
        hashability.
        """
        return list(dict.fromkeys(self.plots.values()))

    def _setup_plots(self):
        if self._plot_valid:
            return
        # Clear every axes and re-apply any user text annotation.
        for f in self.axes:
            self.axes[f].cla()
            if f in self._plot_text:
                self.plots[f].axes.text(
                    self._text_xpos[f],
                    self._text_ypos[f],
                    self._plot_text[f],
                    fontproperties=self._font_properties,
                    **self._text_kwargs[f],
                )

        # One line per (profile, field), drawn on that field's axes.
        for i, profile in enumerate(self.profiles):
            for field, field_data in profile.items():
                self.axes[field].plot(
                    np.array(profile.x),
                    np.array(field_data),
                    label=self.label[i],
                    **self.plot_spec[i],
                )

        # Scales, labels, limits and titles; later profiles overwrite what
        # earlier ones applied to the same axes.
        for profile in self.profiles:
            for fname in profile.keys():
                axes = self.axes[fname]
                xscale, yscale = self._get_field_log(fname, profile)
                xtitle, ytitle = self._get_field_title(fname, profile)

                axes.set_xscale(xscale)
                axes.set_yscale(yscale)

                axes.set_ylabel(ytitle)
                axes.set_xlabel(xtitle)

                axes.set_ylim(*self.axes.ylim[fname])
                axes.set_xlim(*self.axes.xlim)

                if fname in self._plot_title:
                    axes.set_title(self._plot_title[fname])

                if any(self.label):
                    axes.legend(loc="best")
        self._set_font_properties()
        self._plot_valid = True

    @classmethod
    def _initialize_instance(cls, obj, profiles, labels, plot_specs, y_log):
        obj._plot_title = {}
        obj._plot_text = {}
        obj._text_xpos = {}
        obj._text_ypos = {}
        obj._text_kwargs = {}

        from matplotlib.font_manager import FontProperties

        obj._font_properties = FontProperties(family="stixgeneral", size=18)
        obj._font_color = None
        obj.profiles = list(always_iterable(profiles))
        obj.x_log = None
        obj.y_log = sanitize_field_tuple_keys(y_log, obj.profiles[0].data_source) or {}
        obj.y_title = {}
        obj.x_title = None
        obj.label = sanitize_label(labels, len(obj.profiles))
        if plot_specs is None:
            plot_specs = [{} for _ in obj.profiles]
        obj.plot_spec = plot_specs
        obj.plots = PlotContainerDict()
        obj.figures = FigureContainer(obj.plots)
        obj.axes = AxesContainer(obj.plots)
        obj._setup_plots()
        return obj

    @classmethod
    def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None):
        r"""
        Instantiate a ProfilePlot object from a list of profiles
        created with :func:`~yt.data_objects.profiles.create_profile`.

        Parameters
        ----------
        profiles : a profile or list of profiles
            A single profile or list of profile objects created with
            :func:`~yt.data_objects.profiles.create_profile`.
        labels : str or list of strings
            A list of labels for each profile to be overplotted, or a single
            string applied to every profile.
            Default: None.
        plot_specs : dict or list of dicts
            A list of dictionaries containing plot keyword
            arguments.  For example, [dict(color="red", linestyle=":")].
            A single dictionary is copied for every profile.
            Default: None.
        y_log : dict or None
            A dictionary of field:boolean pairs setting the logarithmic
            property of each profiled field.
            Default: None.

        Raises
        ------
        RuntimeError
            If a list of labels or of plot_specs does not have one entry per
            profile.

        Examples
        --------

        >>> from yt import load_simulation
        >>> from yt.data_objects.profiles import create_profile
        >>> es = load_simulation("AMRCosmology.enzo", "Enzo")
        >>> es.get_time_series()

        >>> profiles = []
        >>> labels = []
        >>> plot_specs = []
        >>> for ds in es[-4:]:
        ...     ad = ds.all_data()
        ...     profiles.append(
        ...         create_profile(
        ...             ad,
        ...             [("gas", "density")],
        ...             fields=[("gas", "temperature"), ("gas", "velocity_x")],
        ...         )
        ...     )
        ...     labels.append(f"z = {ds.current_redshift:.2f}")
        ...     plot_specs.append(dict(linestyle="--", alpha=0.7))
        >>> plot = ProfilePlot.from_profiles(
        ...     profiles, labels=labels, plot_specs=plot_specs
        ... )
        >>> plot.save()

        """
        # Accept a single profile as documented; _initialize_instance applies
        # the same normalization, so doing it here first is harmless.
        profiles = list(always_iterable(profiles))
        # A lone string is one label (broadcast by sanitize_label); its
        # len() would count characters, so only size-check real sequences.
        if (
            labels is not None
            and not isinstance(labels, str)
            and len(profiles) != len(labels)
        ):
            raise RuntimeError("Profiles list and labels list must be the same size.")
        if isinstance(plot_specs, dict):
            plot_specs = [plot_specs.copy() for _ in profiles]
        if plot_specs is not None and len(plot_specs) != len(profiles):
            raise RuntimeError(
                "Profiles list and plot_specs list must be the same size."
            )
        obj = cls.__new__(cls)
        return cls._initialize_instance(obj, profiles, labels, plot_specs, y_log)

    @invalidate_plot
    def set_line_property(self, property, value, index=None):
        r"""
        Set properties for one or all lines to be plotted.

        Parameters
        ----------
        property : str
            The line property to be set.
        value : str, int, float
            The value to set for the line property.
        index : int
            The index of the profile in the list of profiles to be
            changed.  If None, change all plotted lines.
            Default : None.

        Examples
        --------

        Change all the lines in a plot
        plot.set_line_property("linestyle", "-")

        Change a single line.
        plot.set_line_property("linewidth", 4, index=0)

        """
        if index is None:
            specs = self.plot_spec
        else:
            specs = [self.plot_spec[index]]
        for spec in specs:
            spec[property] = value
        return self

    @invalidate_plot
    def set_log(self, field, log):
        """set a field to log or linear.

        Parameters
        ----------
        field : string
            the field to set a transform, or "all" for the x field and
            every profiled field
        log : boolean
            Log on/off.

        Raises
        ------
        KeyError
            If the field is neither the x field nor a profiled field.
        """
        if field == "all":
            self.x_log = log
            for fname in list(self.profiles[0].field_data.keys()):
                self.y_log[fname] = log
        else:
            (field,) = self.profiles[0].data_source._determine_fields([field])
            if field == self.profiles[0].x_field:
                self.x_log = log
            elif field in self.profiles[0].field_data:
                self.y_log[field] = log
            else:
                raise KeyError(f"Field {field} not in profile plot!")
        return self

    @invalidate_plot
    def set_ylabel(self, field, label):
        """Sets a new ylabel for the specified fields

        Parameters
        ----------
        field : string
           The name of the field that is to be changed, or "all".

        label : string
           The label to be placed on the y-axis

        Raises
        ------
        KeyError
            If the field is not a profiled field.
        """
        if field == "all":
            for fname in self.profiles[0].field_data:
                self.y_title[fname] = label
        else:
            (field,) = self.profiles[0].data_source._determine_fields([field])
            if field in self.profiles[0].field_data:
                self.y_title[field] = label
            else:
                raise KeyError(f"Field {field} not in profile plot!")

        return self

    @invalidate_plot
    def set_xlabel(self, label):
        """Sets a new xlabel for all profiles

        Parameters
        ----------
        label : string
           The label to be placed on the x-axis
        """
        self.x_title = label

        return self

    @invalidate_plot
    def set_unit(self, field, unit):
        """Sets a new unit for the requested field

        Parameters
        ----------
        field : string
           The name of the field that is to be changed.

        unit : string or Unit object
           The name of the new unit.

        Raises
        ------
        KeyError
            If the field is neither the x field nor a profiled field.
        """
        fd = self.profiles[0].data_source._determine_fields(field)[0]
        for profile in self.profiles:
            if fd == profile.x_field:
                profile.set_x_unit(unit)
            elif fd[1] in self.profiles[0].field_map:
                profile.set_field_unit(field, unit)
            else:
                raise KeyError(f"Field {field} not in profile plot!")
        return self

    @invalidate_plot
    def set_xlim(self, xmin=None, xmax=None):
        """Sets the limits of the bin field

        Each profile is re-binned over the new range with its original
        number of bins, weight field, accumulation and fractional settings.

        Parameters
        ----------

        xmin : float or None
          The new x minimum.  Defaults to None, which leaves the xmin
          unchanged.

        xmax : float or None
          The new x maximum.  Defaults to None, which leaves the xmax
          unchanged.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> pp = yt.ProfilePlot(
        ...     ds.all_data(), ("gas", "density"), ("gas", "temperature")
        ... )
        >>> pp.set_xlim(1e-29, 1e-24)
        >>> pp.save()

        """
        self.axes.xlim = (xmin, xmax)
        for i, p in enumerate(self.profiles):
            # Unspecified bounds fall back to the current profile's bin edges.
            xmi = p.x_bins.min() if xmin is None else xmin
            xma = p.x_bins.max() if xmax is None else xmax
            extrema = {p.x_field: ((xmi, str(p.x.units)), (xma, str(p.x.units)))}
            # Preserve the units currently shown for the x and y fields.
            units = {p.x_field: str(p.x.units)}
            if self.x_log is None:
                logs = None
            else:
                logs = {p.x_field: self.x_log}
            for field in p.field_map.values():
                units[field] = str(p.field_data[field].units)
            self.profiles[i] = create_profile(
                p.data_source,
                p.x_field,
                n_bins=len(p.x_bins) - 1,
                fields=list(p.field_map.values()),
                weight_field=p.weight_field,
                accumulation=p.accumulation,
                fractional=p.fractional,
                logs=logs,
                extrema=extrema,
                units=units,
            )
        return self

    @invalidate_plot
    def set_ylim(self, field, ymin=None, ymax=None):
        """Sets the y-axis plot limits for the specified profiled field(s).

        Parameters
        ----------

        field : string, field tuple, or list of them

        The profiled field(s) whose plot limits should be adjusted, or
        "all" for every plotted field.

        ymin : float or None
          The new y minimum.  If None, the automatically determined limit
          of the redrawn plot is kept.

        ymax : float or None
          The new y maximum.  If None, the automatically determined limit
          of the redrawn plot is kept.

        Each call replaces both stored limits of the selected field(s).

        Examples
        --------

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> pp = yt.ProfilePlot(
        ...     ds.all_data(),
        ...     ("gas", "density"),
        ...     [("gas", "temperature"), ("gas", "velocity_x")],
        ... )
        >>> pp.set_ylim(("gas", "temperature"), 1e4, 1e6)
        >>> pp.save()

        """
        fields = list(self.axes.keys()) if field == "all" else field
        for profile in self.profiles:
            # Every requested field gets the limits (not just the first one),
            # matching annotate_title / annotate_text.
            for fname in profile.data_source._determine_fields(fields):
                if fname in profile.field_map:
                    fname = profile.field_map[fname]
                self.axes.ylim[fname] = (ymin, ymax)
        return self

    def _set_font_properties(self):
        for f in self.plots:
            self.plots[f]._set_font_properties(self._font_properties, self._font_color)

    def _get_field_log(self, field_y, profile):
        """Return the (x, y) matplotlib scale names for ``field_y``.

        An explicit ``self.x_log`` overrides the profile's own x_log. The
        y scale comes from ``self.y_log``, falling back to the field's
        ``take_log``.
        """
        yfi = profile.field_info[field_y]
        if self.x_log is None:
            x_log = profile.x_log
        else:
            x_log = self.x_log
        y_log = self.y_log.get(field_y, yfi.take_log)
        scales = {True: "log", False: "linear"}
        return scales[x_log], scales[y_log]

    def _get_field_label(self, field, field_info, field_unit, fractional=False):
        """Build a LaTeX axis label from a field's display name and unit."""
        # Guard so a missing unit yields a unit-less label instead of an
        # AttributeError.
        if field_unit is not None:
            field_unit = field_unit.latex_representation()
        field_name = field_info.display_name
        if isinstance(field, tuple):
            field = field[1]
        if field_name is None:
            field_name = r"$\rm{" + field.replace("_", r"\ ").title() + r"}$"
        elif field_name.find("$") == -1:
            field_name = field_name.replace(" ", r"\ ")
            field_name = r"$\rm{" + field_name + r"}$"
        if fractional:
            label = field_name + r"$\rm{\ Probability\ Density}$"
        elif field_unit is None or field_unit == "":
            label = field_name
        else:
            label = field_name + r"$\ \ (" + field_unit + r")$"
        return label

    def _get_field_title(self, field_y, profile):
        """Return (x_title, y_title), preferring user-set titles."""
        field_x = profile.x_field
        xfi = profile.field_info[field_x]
        yfi = profile.field_info[field_y]
        x_unit = profile.x.units
        y_unit = profile.field_units[field_y]
        fractional = profile.fractional
        x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit)
        y_title = self.y_title.get(field_y, None) or self._get_field_label(
            field_y, yfi, y_unit, fractional
        )

        return (x_title, y_title)

    @invalidate_plot
    def annotate_title(self, title, field="all"):
        r"""Set a title for the plot.

        Parameters
        ----------
        title : str
          The title to add.
        field : str or list of str
          The field name for which title needs to be set.

        Examples
        --------
        >>> # To set title for all the fields:
        >>> plot.annotate_title("This is a Profile Plot")

        >>> # To set title for specific fields:
        >>> plot.annotate_title("Profile Plot for Temperature", ("gas", "temperature"))

        >>> # Setting same plot title for both the given fields
        >>> plot.annotate_title(
        ...     "Profile Plot: Temperature-Dark Matter Density",
        ...     [("gas", "temperature"), ("deposit", "dark_matter_density")],
        ... )

        """
        fields = list(self.axes.keys()) if field == "all" else field
        for profile in self.profiles:
            for fname in profile.data_source._determine_fields(fields):
                if fname in profile.field_map:
                    fname = profile.field_map[fname]
                self._plot_title[fname] = title
        return self

    @invalidate_plot
    def annotate_text(self, xpos=0.0, ypos=0.0, text=None, field="all", **text_kwargs):
        r"""Allow the user to insert text onto the plot

        The x-position and y-position must be given as well as the text string.
        Add *text* to plot at location *xpos*, *ypos* in plot coordinates for
        the given fields or by default for all fields.
        (see example below).

        Parameters
        ----------
        xpos : float
          Position on plot in x-coordinates.
        ypos : float
          Position on plot in y-coordinates.
        text : str
          The text to insert onto the plot.
        field : str or tuple
          The name of the field to add text to.
        **text_kwargs : dict
          Extra keyword arguments will be passed to matplotlib text instance

        >>> import yt
        >>> from yt.units import kpc
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> my_galaxy = ds.disk(ds.domain_center, [0.0, 0.0, 1.0], 10 * kpc, 3 * kpc)
        >>> plot = yt.ProfilePlot(
        ...     my_galaxy, ("gas", "density"), [("gas", "temperature")]
        ... )

        >>> # Annotate text for all the fields
        >>> plot.annotate_text(1e-26, 1e5, "This is annotated text in the plot area.")
        >>> plot.save()

        >>> # Annotate text for a given field
        >>> plot.annotate_text(1e-26, 1e5, "Annotated text", ("gas", "temperature"))
        >>> plot.save()

        >>> # Annotate text for multiple fields
        >>> fields = [("gas", "temperature"), ("gas", "density")]
        >>> plot.annotate_text(1e-26, 1e5, "Annotated text", fields)
        >>> plot.save()

        """
        fields = list(self.axes.keys()) if field == "all" else field
        for profile in self.profiles:
            for fname in profile.data_source._determine_fields(fields):
                if fname in profile.field_map:
                    fname = profile.field_map[fname]
                self._plot_text[fname] = text
                self._text_xpos[fname] = xpos
                self._text_ypos[fname] = ypos
                self._text_kwargs[fname] = text_kwargs
        return self
847
848
class PhasePlot(ImagePlotContainer):
    r"""
    Create a 2d profile (phase) plot from a data source or from
    profile object created with
    `yt.data_objects.profiles.create_profile`.

    Given a data object (all_data, region, sphere, etc.), an x field,
    y field, and z field (or fields), this will create a two-dimensional
    profile of the average (or total) value of the z field in bins of the
    x and y fields.

    Parameters
    ----------
    data_source : YTSelectionContainer Object
        The data object to be profiled, such as all_data, region, or
        sphere. If a dataset is passed in instead, an all_data data object
        is generated internally from the dataset.
    x_field : str or tuple of str
        The x binning field for the profile.
    y_field : str or tuple of str
        The y binning field for the profile.
    z_fields : str, tuple of str, or list
        The field or fields to be profiled.
    weight_field : str or tuple of str
        The weight field for calculating weighted averages.  If None,
        the profile values are the sum of the field values within the bin.
        Otherwise, the values are a weighted average.
        Default : ("gas", "mass")
    x_bins : int
        The number of bins in x field for the profile.
        Default: 128.
    y_bins : int
        The number of bins in y field for the profile.
        Default: 128.
    accumulation : bool or list of bools
        If True, the profile values for a bin n are the cumulative sum of
        all the values from bin 0 to n.  If -True, the sum is reversed so
        that the value for bin n is the cumulative sum from bin N (total bins)
        to n.  A list of values can be given to control the summation in each
        dimension independently.
        Default: False.
    fractional : bool
        If True the profile values are divided by the sum of all
        the profile data such that the profile represents a probability
        distribution function.
        Default: False.
    fontsize : int
        Font size for all text in the plot.
        Default: 18.
    figure_size : float
        Size in inches of the image.
        Default: 8.0 (8x8)
    shading : str
        This argument is directly passed down to matplotlib.axes.Axes.pcolormesh
        see
        https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py  # noqa
        Default: 'nearest'

    Examples
    --------

    >>> import yt
    >>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046")
    >>> ad = ds.all_data()
    >>> plot = yt.PhasePlot(
    ...     ad,
    ...     ("gas", "density"),
    ...     ("gas", "temperature"),
    ...     [("gas", "mass")],
    ...     weight_field=None,
    ... )
    >>> plot.save()

    >>> # Change plot properties.
    >>> plot.set_cmap(("gas", "mass"), "jet")
    >>> plot.set_zlim(("gas", "mass"), 1e8, 1e13)
    >>> plot.annotate_title("This is a phase plot")

    """
    x_log = None
    y_log = None
    plot_title = None
    _plot_valid = False
    _profile_valid = False
    _plot_type = "Phase"
    _xlim = (None, None)
    _ylim = (None, None)

    def __init__(
        self,
        data_source,
        x_field,
        y_field,
        z_fields,
        weight_field=("gas", "mass"),
        x_bins=128,
        y_bins=128,
        accumulation=False,
        fractional=False,
        fontsize=18,
        figure_size=8.0,
        shading="nearest",
    ):
        data_source = data_object_or_all_data(data_source)

        # A single (ftype, fname) tuple is one field, not a sequence of fields.
        if isinstance(z_fields, tuple):
            z_fields = [z_fields]
        z_fields = list(always_iterable(z_fields))

        if isinstance(data_source.ds, YTProfileDataset):
            # A reloaded profile dataset already carries its profile.
            profile = data_source.ds.profile
        else:
            profile = create_profile(
                data_source,
                [x_field, y_field],
                z_fields,
                n_bins=[x_bins, y_bins],
                weight_field=weight_field,
                accumulation=accumulation,
                fractional=fractional,
            )

        type(self)._initialize_instance(
            self, data_source, profile, fontsize, figure_size, shading
        )

    @classmethod
    def _initialize_instance(
        cls, obj, data_source, profile, fontsize, figure_size, shading
    ):
        """Initialize the state of *obj* (shared by ``__init__`` and
        ``from_profile``) and build its plots. Returns *obj*."""
        obj.plot_title = {}
        obj.z_log = {}
        obj.z_title = {}
        obj._initfinished = False
        obj.x_log = None
        obj.y_log = None
        obj._plot_text = {}
        obj._text_xpos = {}
        obj._text_ypos = {}
        obj._text_kwargs = {}
        obj._profile = profile
        obj._shading = shading
        obj._profile_valid = True
        obj._xlim = (None, None)
        obj._ylim = (None, None)
        super(PhasePlot, obj).__init__(data_source, figure_size, fontsize)
        obj._setup_plots()
        obj._initfinished = True
        return obj

    def _get_field_title(self, field_z, profile):
        """Return the (x, y, z) axis titles for the plot of *field_z*.

        User-supplied axis labels take precedence; otherwise labels are
        built from the profile's field info and units.
        """
        field_x = profile.x_field
        field_y = profile.y_field
        xfi = profile.field_info[field_x]
        yfi = profile.field_info[field_y]
        zfi = profile.field_info[field_z]
        x_unit = profile.x.units
        y_unit = profile.y.units
        z_unit = profile.field_units[field_z]
        fractional = profile.fractional
        x_label, y_label, z_label = self._get_axes_labels(field_z)
        x_title = x_label or self._get_field_label(field_x, xfi, x_unit)
        y_title = y_label or self._get_field_label(field_y, yfi, y_unit)
        z_title = z_label or self._get_field_label(field_z, zfi, z_unit, fractional)
        return (x_title, y_title, z_title)

    def _get_field_label(self, field, field_info, field_unit, fractional=False):
        """Build a mathtext label for *field*, with units or a
        "Probability Density" suffix when *fractional* is true."""
        field_unit = field_unit.latex_representation()
        field_name = field_info.display_name
        if isinstance(field, tuple):
            field = field[1]
        if field_name is None:
            field_name = r"$\rm{" + field.replace("_", r"\ ").title() + r"}$"
        elif field_name.find("$") == -1:
            # Plain-text display names: escape spaces for mathtext.
            field_name = field_name.replace(" ", r"\ ")
            field_name = r"$\rm{" + field_name + r"}$"
        if fractional:
            label = field_name + r"$\rm{\ Probability\ Density}$"
        elif field_unit is None or field_unit == "":
            label = field_name
        else:
            label = field_name + r"$\ \ (" + field_unit + r")$"
        return label

    def _get_field_log(self, field_z, profile):
        """Return the ("log"|"linear") scales for the x, y and z axes.

        User overrides stored on the plot win over the profile/field defaults.
        """
        zfi = profile.field_info[field_z]
        if self.x_log is None:
            x_log = profile.x_log
        else:
            x_log = self.x_log
        if self.y_log is None:
            y_log = profile.y_log
        else:
            y_log = self.y_log
        if field_z in self.z_log:
            z_log = self.z_log[field_z]
        else:
            z_log = zfi.take_log
        scales = {True: "log", False: "linear"}
        return scales[x_log], scales[y_log], scales[z_log]

    def _recreate_frb(self):
        # needed for API compatibility with PlotWindow
        pass

    @property
    def profile(self):
        """The underlying profile, recreated lazily if it was invalidated."""
        if not self._profile_valid:
            self._recreate_profile()
        return self._profile

    @property
    def fields(self):
        return list(self.plots.keys())

    def _setup_plots(self):
        """(Re)build one PhasePlotMPL per profiled field.

        Existing figures, axes, colorbar axes, zlims and axis/colorbar
        visibility are reused when a plot for the field already exists.
        """
        if self._plot_valid:
            return
        # Keys of every plot built below, so post-loop actions apply to all.
        plotted_fields = []
        for f, data in self.profile.items():
            fig = None
            axes = None
            cax = None
            draw_colorbar = True
            draw_axes = True
            zlim = (None, None)
            xlim = self._xlim
            ylim = self._ylim
            if f in self.plots:
                draw_colorbar = self.plots[f]._draw_colorbar
                draw_axes = self.plots[f]._draw_axes
                zlim = (self.plots[f].zmin, self.plots[f].zmax)
                if self.plots[f].figure is not None:
                    fig = self.plots[f].figure
                    axes = self.plots[f].axes
                    cax = self.plots[f].cax

            x_scale, y_scale, z_scale = self._get_field_log(f, self.profile)
            x_title, y_title, z_title = self._get_field_title(f, self.profile)

            if zlim == (None, None):
                if z_scale == "log":
                    positive_values = data[data > 0.0]
                    if len(positive_values) == 0:
                        # A log colorbar is impossible without positive data.
                        mylog.warning(
                            "Profiled field %s has no positive values. Max = %f.",
                            f,
                            np.nanmax(data),
                        )
                        mylog.warning("Switching to linear colorbar scaling.")
                        zmin = np.nanmin(data)
                        z_scale = "linear"
                        self._field_transform[f] = linear_transform
                    else:
                        zmin = positive_values.min()
                        self._field_transform[f] = log_transform
                else:
                    zmin = np.nanmin(data)
                    self._field_transform[f] = linear_transform
                zlim = [zmin, np.nanmax(data)]

            font_size = self._font_properties.get_size()
            f = self.profile.data_source._determine_fields(f)[0]

            # if this is a Particle Phase Plot AND if we are using a single
            # color, override the colorbar here.
            splat_color = getattr(self, "splat_color", None)
            if splat_color is not None:
                cmap = matplotlib.colors.ListedColormap(splat_color, "dummy")
            else:
                cmap = self._colormap_config[f]

            self.plots[f] = PhasePlotMPL(
                self.profile.x,
                self.profile.y,
                data,
                x_scale,
                y_scale,
                z_scale,
                cmap,
                zlim,
                self.figure_size,
                font_size,
                fig,
                axes,
                cax,
                shading=self._shading,
            )
            plotted_fields.append(f)

            self.plots[f]._toggle_axes(draw_axes)
            self.plots[f]._toggle_colorbar(draw_colorbar)

            self.plots[f].axes.xaxis.set_label_text(x_title)
            self.plots[f].axes.yaxis.set_label_text(y_title)
            self.plots[f].cax.yaxis.set_label_text(z_title)

            self.plots[f].axes.set_xlim(xlim)
            self.plots[f].axes.set_ylim(ylim)

            color = self._background_color[f]

            self.plots[f].axes.set_facecolor(color)

            if f in self._plot_text:
                self.plots[f].axes.text(
                    self._text_xpos[f],
                    self._text_ypos[f],
                    self._plot_text[f],
                    fontproperties=self._font_properties,
                    **self._text_kwargs[f],
                )

            if f in self.plot_title:
                self.plots[f].axes.set_title(self.plot_title[f])

            # x-y axes minorticks
            if f not in self._minorticks:
                self._minorticks[f] = True
            if self._minorticks[f]:
                self.plots[f].axes.minorticks_on()
            else:
                self.plots[f].axes.minorticks_off()

            # colorbar minorticks
            if f not in self._cbar_minorticks:
                self._cbar_minorticks[f] = True
            if self._cbar_minorticks[f]:
                if self._field_transform[f] == linear_transform:
                    self.plots[f].cax.minorticks_on()
                elif MPL_VERSION < parse_version("3.0.0"):
                    # before matplotlib 3 log-scaled colorbars internally used
                    # a linear scale going from zero to one and did not draw
                    # minor ticks. Since we want minor ticks, calculate
                    # where the minor ticks should go in this linear scale
                    # and add them manually.
                    vmin = np.float64(self.plots[f].cb.norm.vmin)
                    vmax = np.float64(self.plots[f].cb.norm.vmax)
                    mticks = self.plots[f].image.norm(get_log_minorticks(vmin, vmax))
                    self.plots[f].cax.yaxis.set_ticks(mticks, minor=True)
            else:
                self.plots[f].cax.minorticks_off()

        self._set_font_properties()

        # if this is a particle plot with one color only, hide the cbar on
        # every plot (not just the last one built by the loop above).
        if hasattr(self, "use_cbar") and not self.use_cbar:
            for field in plotted_fields:
                self.plots[field].hide_colorbar()

        self._plot_valid = True

    @classmethod
    def from_profile(cls, profile, fontsize=18, figure_size=8.0, shading="nearest"):
        r"""
        Instantiate a PhasePlot object from a profile object created
        with :func:`~yt.data_objects.profiles.create_profile`.

        Parameters
        ----------
        profile : An instance of :class:`~yt.data_objects.profiles.ProfileND`
             A single profile object.
        fontsize : float
             The fontsize to use, in points.
        figure_size : float
             The figure size to use, in inches.
        shading : str
            This argument is directly passed down to matplotlib.axes.Axes.pcolormesh
            see
            https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py  # noqa
            Default: 'nearest'

        Examples
        --------

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> extrema = {
        ...     ("gas", "density"): (1e-31, 1e-24),
        ...     ("gas", "temperature"): (1e1, 1e8),
        ...     ("gas", "mass"): (1e-6, 1e-1),
        ... }
        >>> profile = yt.create_profile(
        ...     ds.all_data(),
        ...     [("gas", "density"), ("gas", "temperature")],
        ...     fields=[("gas", "mass")],
        ...     extrema=extrema,
        ...     fractional=True,
        ... )
        >>> ph = yt.PhasePlot.from_profile(profile)
        >>> ph.save()
        """
        obj = cls.__new__(cls)
        data_source = profile.data_source
        return cls._initialize_instance(
            obj, data_source, profile, fontsize, figure_size, shading
        )

    def annotate_text(self, xpos=0.0, ypos=0.0, text=None, **text_kwargs):
        r"""
        Allow the user to insert text onto the plot
        The x-position and y-position must be given as well as the text string.
        Add *text* to plot at location *xpos*, *ypos* in plot coordinates
        (see example below).

        Parameters
        ----------
        xpos : float
          Position on plot in x-coordinates.
        ypos : float
          Position on plot in y-coordinates.
        text : str
          The text to insert onto the plot.
        **text_kwargs : dict
          Extra keyword arguments will be passed to matplotlib text instance

        >>> plot.annotate_text(1e-15, 5e4, "Hello YT")

        """
        for f in self.data_source._determine_fields(list(self.plots.keys())):
            if self.plots[f].figure is not None and text is not None:
                self.plots[f].axes.text(
                    xpos,
                    ypos,
                    text,
                    fontproperties=self._font_properties,
                    **text_kwargs,
                )
            # Remember the annotation so it is redrawn when plots are rebuilt.
            self._plot_text[f] = text
            self._text_xpos[f] = xpos
            self._text_ypos[f] = ypos
            self._text_kwargs[f] = text_kwargs
        return self

    @validate_plot
    def save(self, name=None, suffix=".png", mpl_kwargs=None):
        r"""
        Saves a 2d profile plot.

        Parameters
        ----------
        name : str
            The output file keyword. If it already carries a supported image
            extension, every plot is written to exactly that filename.
            Otherwise one file per profiled field is written, named
            ``<name>_2d-Profile_<x>_<y>_<z><suffix>``.
        suffix : string
           Specify the image type by its suffix. If not specified, the output
           type will be inferred from the filename. Defaults to PNG.
        mpl_kwargs : dict
           A dict of keyword arguments to be passed to matplotlib.

        Returns
        -------
        list of str
            The filenames written.

        >>> plot.save(mpl_kwargs={"bbox_inches": "tight"})

        """
        names = []
        if not self._plot_valid:
            self._setup_plots()
        if mpl_kwargs is None:
            mpl_kwargs = {}
        if name is None:
            name = str(self.profile.ds)
        name = os.path.expanduser(name)
        xfn = self.profile.x_field
        yfn = self.profile.y_field
        if isinstance(xfn, tuple):
            xfn = xfn[1]
        if isinstance(yfn, tuple):
            yfn = yfn[1]

        fields = list(self.profile.field_data)
        if not fields:
            return names

        # Resolve the base name ONCE; mutating it per field made the second
        # and later fields overwrite the first field's file.
        splitname = os.path.split(name)
        if splitname[0] != "" and not os.path.isdir(splitname[0]):
            os.makedirs(splitname[0])
        if os.path.isdir(name) and name != str(self.profile.ds):
            name = name + (os.sep if name[-1] != os.sep else "")
            name += str(self.profile.ds)

        new_name = validate_image_name(name, suffix)
        if new_name == name:
            # The user gave an explicit filename with an image extension.
            for v in self.plots.values():
                out_name = v.save(name, mpl_kwargs)
                names.append(out_name)
            return names

        prefix, ext = os.path.splitext(new_name)
        for f in fields:
            _f = f[1] if isinstance(f, tuple) else f
            middle = f"2d-Profile_{xfn}_{yfn}_{_f}"
            out_name = f"{prefix}_{middle}{ext}"
            names.append(out_name)
            self.plots[f].save(out_name, mpl_kwargs)
        return names

    @invalidate_plot
    def set_font(self, font_dict=None):
        """

        Set the font and font properties.

        Parameters
        ----------

        font_dict : dict
            A dict of keyword parameters to be passed to
            :class:`matplotlib.font_manager.FontProperties`.
            The dict is not modified.

            Possible keys include:

            * family - The font family. Can be serif, sans-serif, cursive,
              fantasy, or monospace.
            * style - The font style. Either normal, italic or oblique.
            * color - A valid color string like 'r', 'g', 'red', 'cobalt',
              and 'orange'.
            * variant - Either normal or small-caps.
            * size - Either a relative value of xx-small, x-small, small,
              medium, large, x-large, xx-large or an absolute font size, e.g. 12
            * stretch - A numeric value in the range 0-1000 or one of
              ultra-condensed, extra-condensed, condensed, semi-condensed,
              normal, semi-expanded, expanded, extra-expanded or ultra-expanded
            * weight - A numeric value in the range 0-1000 or one of ultralight,
              light, normal, regular, book, medium, roman, semibold, demibold,
              demi, bold, heavy, extra bold, or black

            See the matplotlib font manager API documentation for more details.
            https://matplotlib.org/stable/api/font_manager_api.html

        Notes
        -----

        Mathtext axis labels will only obey the `size` and `color` keyword.

        Examples
        --------

        This sets the font to be 24-pt, blue, sans-serif, italic, and
        bold-face.

        >>> plot = PhasePlot(
        ...     ds.all_data(),
        ...     ("gas", "density"),
        ...     ("gas", "temperature"),
        ...     ("gas", "mass"),
        ... )
        >>> plot.set_font(
        ...     {
        ...         "family": "sans-serif",
        ...         "style": "italic",
        ...         "weight": "bold",
        ...         "size": 24,
        ...         "color": "blue",
        ...     }
        ... )

        """
        from matplotlib.font_manager import FontProperties

        # Copy so the caller's dict is not mutated by pop/setdefault below.
        font_dict = {} if font_dict is None else dict(font_dict)
        if "color" in font_dict:
            self._font_color = font_dict.pop("color")
        # Set default values if the user does not explicitly set them.
        # this prevents reverting to the matplotlib defaults.
        font_dict.setdefault("family", "stixgeneral")
        font_dict.setdefault("size", 18)
        self._font_properties = FontProperties(**font_dict)
        return self

    @invalidate_plot
    def set_title(self, field, title):
        """Set a title for the plot.

        Parameters
        ----------
        field : str
            The z field of the plot to add the title.
        title : str
            The title to add.

        Examples
        --------

        >>> plot.set_title(("gas", "mass"), "This is a phase plot")
        """
        self.plot_title[self.data_source._determine_fields(field)[0]] = title
        return self

    @invalidate_plot
    def annotate_title(self, title):
        """Set a title for every plot.

        Parameters
        ----------
        title : str
            The title to add.

        Examples
        --------

        >>> plot.annotate_title("This is a phase plot")

        """
        for f in self._profile.field_data:
            # Use the fully-qualified key directly; stripping the field type
            # and re-resolving the bare name can pick a different field.
            self.plot_title[self.data_source._determine_fields(f)[0]] = title
        return self

    @invalidate_plot
    def reset_plot(self):
        """Discard all existing plots so they are rebuilt from scratch."""
        self.plots = {}
        return self

    @invalidate_plot
    def set_log(self, field, log):
        """set a field to log or linear.

        Parameters
        ----------
        field : string
            the field to set a transform, or "all" for every axis
        log : boolean
            Log on/off.

        Raises
        ------
        KeyError
            If *field* is not one of the plot's x, y or z fields.
        """
        p = self._profile
        if field == "all":
            self.x_log = log
            self.y_log = log
            for field in p.field_data:
                self.z_log[field] = log
            self._profile_valid = False
        else:
            # Resolve via the cached profile: going through ``self.profile``
            # could needlessly recreate the profile only to invalidate it.
            (field,) = p.data_source._determine_fields([field])
            if field == p.x_field:
                self.x_log = log
                self._profile_valid = False
            elif field == p.y_field:
                self.y_log = log
                self._profile_valid = False
            elif field in p.field_data:
                self.z_log[field] = log
            else:
                raise KeyError(f"Field {field} not in phase plot!")
        return self

    @invalidate_plot
    def set_unit(self, field, unit):
        """Sets a new unit for the requested field

        Parameters
        ----------
        field : string
           The name of the field that is to be changed.

        unit : string or Unit object
           The name of the new unit.

        Raises
        ------
        KeyError
            If *field* is not one of the plot's x, y or z fields.
        """
        fd = self.data_source._determine_fields(field)[0]
        if fd == self.profile.x_field:
            self.profile.set_x_unit(unit)
        elif fd == self.profile.y_field:
            self.profile.set_y_unit(unit)
        elif fd in self.profile.field_data.keys():
            self.profile.set_field_unit(fd, unit)
            # Reset zlim so it is recomputed in the new units.
            self.plots[fd].zmin, self.plots[fd].zmax = (None, None)
        else:
            raise KeyError(f"Field {field} not in phase plot!")
        return self

    @invalidate_plot
    @invalidate_profile
    def set_xlim(self, xmin=None, xmax=None):
        """Sets the limits of the x bin field

        Parameters
        ----------

        xmin : float or None
          The new x minimum in the current x-axis units.  Defaults to None,
          which leaves the xmin unchanged.

        xmax : float or None
          The new x maximum in the current x-axis units.  Defaults to None,
          which leaves the xmax unchanged.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> pp = yt.PhasePlot(
        ...     ds.all_data(),
        ...     ("gas", "density"),
        ...     ("gas", "temperature"),
        ...     ("gas", "mass"),
        ... )
        >>> pp.set_xlim(1e-29, 1e-24)
        >>> pp.save()

        """
        p = self._profile
        if xmin is None:
            xmin = p.x_bins.min()
        elif not hasattr(xmin, "units"):
            xmin = self.ds.quan(xmin, p.x_bins.units)
        if xmax is None:
            xmax = p.x_bins.max()
        elif not hasattr(xmax, "units"):
            xmax = self.ds.quan(xmax, p.x_bins.units)
        self._xlim = (xmin, xmax)
        return self

    @invalidate_plot
    @invalidate_profile
    def set_ylim(self, ymin=None, ymax=None):
        """Sets the plot limits for the y bin field.

        Parameters
        ----------

        ymin : float or None
          The new y minimum in the current y-axis units.  Defaults to None,
          which leaves the ymin unchanged.

        ymax : float or None
          The new y maximum in the current y-axis units.  Defaults to None,
          which leaves the ymax unchanged.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> pp = yt.PhasePlot(
        ...     ds.all_data(),
        ...     ("gas", "density"),
        ...     ("gas", "temperature"),
        ...     ("gas", "mass"),
        ... )
        >>> pp.set_ylim(1e4, 1e6)
        >>> pp.save()

        """
        p = self._profile
        if ymin is None:
            ymin = p.y_bins.min()
        elif not hasattr(ymin, "units"):
            ymin = self.ds.quan(ymin, p.y_bins.units)
        if ymax is None:
            ymax = p.y_bins.max()
        elif not hasattr(ymax, "units"):
            ymax = self.ds.quan(ymax, p.y_bins.units)
        self._ylim = (ymin, ymax)
        return self

    def _recreate_profile(self):
        """Rebuild the profile with the current limits, log settings and
        units, preserving the old profile's binning and options."""
        p = self._profile
        units = {p.x_field: str(p.x.units), p.y_field: str(p.y.units)}
        zunits = {field: str(p.field_units[field]) for field in p.field_units}
        extrema = {p.x_field: self._xlim, p.y_field: self._ylim}
        if self.x_log is not None or self.y_log is not None:
            logs = {}
        else:
            logs = None
        if self.x_log is not None:
            logs[p.x_field] = self.x_log
        if self.y_log is not None:
            logs[p.y_field] = self.y_log
        deposition = getattr(p, "deposition", None)
        additional_kwargs = {
            "accumulation": p.accumulation,
            "fractional": p.fractional,
            "deposition": deposition,
        }
        self._profile = create_profile(
            p.data_source,
            [p.x_field, p.y_field],
            list(p.field_map.values()),
            n_bins=[len(p.x_bins) - 1, len(p.y_bins) - 1],
            weight_field=p.weight_field,
            units=units,
            extrema=extrema,
            logs=logs,
            **additional_kwargs,
        )
        for field in zunits:
            self._profile.set_field_unit(field, zunits[field])
        self._profile_valid = True
1625
1626
class PhasePlotMPL(ImagePlotMPL):
    """A container for a single matplotlib figure and axes for a PhasePlot"""

    def __init__(
        self,
        x_data,
        y_data,
        data,
        x_scale,
        y_scale,
        z_scale,
        cmap,
        zlim,
        figure_size,
        fontsize,
        figure,
        axes,
        cax,
        shading="nearest",
    ):
        self._initfinished = False
        self._draw_colorbar = True
        self._draw_axes = True
        self._figure_size = figure_size
        self._shading = shading
        # Compute layout; text sizes scale with the requested font size
        # (sub-linearly when shrinking the font).
        fontscale = float(fontsize) / 18.0
        if fontscale < 1.0:
            fontscale = np.sqrt(fontscale)

        if is_sequence(figure_size):
            self._cb_size = 0.0375 * figure_size[0]
        else:
            self._cb_size = 0.0375 * figure_size
        self._ax_text_size = [1.1 * fontscale, 0.9 * fontscale]
        self._top_buff_size = 0.30 * fontscale
        self._aspect = 1.0

        size, axrect, caxrect = self._get_best_layout()

        super().__init__(size, axrect, caxrect, zlim, figure, axes, cax)

        self._init_image(x_data, y_data, data, x_scale, y_scale, z_scale, zlim, cmap)

        self._initfinished = True

    def _init_image(
        self, x_data, y_data, image_data, x_scale, y_scale, z_scale, zlim, cmap
    ):
        """Draw *image_data* with pcolormesh and attach a colorbar.

        The result of ``pcolormesh`` is stored in ``self.image`` and the
        colorbar in ``self.cb``.

        Raises
        ------
        ValueError
            If *z_scale* is neither "log" nor "linear".
        """
        if z_scale == "log":
            norm = matplotlib.colors.LogNorm(zlim[0], zlim[1])
        elif z_scale == "linear":
            norm = matplotlib.colors.Normalize(zlim[0], zlim[1])
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(
                f"Unsupported z_scale {z_scale!r}; expected 'log' or 'linear'."
            )
        self.image = None
        self.cb = None

        # The image data is transposed so its first axis runs along x.
        self.image = self.axes.pcolormesh(
            np.array(x_data),
            np.array(y_data),
            np.array(image_data.T),
            norm=norm,
            cmap=cmap,
            shading=self._shading,
        )

        self.axes.set_xscale(x_scale)
        self.axes.set_yscale(y_scale)
        self.cb = self.figure.colorbar(self.image, self.cax)
        if z_scale == "linear":
            self.cb.formatter.set_scientific(True)
            self.cb.formatter.set_powerlimits((-2, 3))
            self.cb.update_ticks()
1700