1from collections import defaultdict
2from functools import wraps
3from numbers import Number
4
5import matplotlib
6import matplotlib.pyplot as plt
7import numpy as np
8from more_itertools import always_iterable
9from mpl_toolkits.axes_grid1 import ImageGrid
10from packaging.version import parse as parse_version
11from unyt.exceptions import UnitConversionError
12
13from yt._maintenance.deprecation import issue_deprecation_warning
14from yt.config import ytcfg
15from yt.data_objects.image_array import ImageArray
16from yt.frontends.ytdata.data_structures import YTSpatialPlotDataset
17from yt.funcs import fix_axis, fix_unitary, is_sequence, iter_fields, mylog, obj_length
18from yt.units.unit_object import Unit
19from yt.units.unit_registry import UnitParseError
20from yt.units.yt_array import YTArray, YTQuantity
21from yt.utilities.exceptions import (
22    YTCannotParseUnitDisplayName,
23    YTDataTypeUnsupported,
24    YTInvalidFieldType,
25    YTPlotCallbackError,
26    YTUnitNotRecognized,
27)
28from yt.utilities.math_utils import ortho_find
29from yt.utilities.orientation import Orientation
30
31from .base_plot_types import CallbackWrapper, ImagePlotMPL
32from .fixed_resolution import (
33    FixedResolutionBuffer,
34    OffAxisProjectionFixedResolutionBuffer,
35)
36from .geo_plot_utils import get_mpl_transform
37from .plot_container import (
38    ImagePlotContainer,
39    apply_callback,
40    get_log_minorticks,
41    get_symlog_minorticks,
42    invalidate_data,
43    invalidate_figure,
44    invalidate_plot,
45    linear_transform,
46    log_transform,
47    symlog_transform,
48)
49from .plot_modifications import callback_registry
50
51import sys  # isort: skip
52
53if sys.version_info < (3, 10):
54    # this function is deprecated in more_itertools
55    # because it is superseded by the standard library
56    from more_itertools import zip_equal
57else:
58
59    def zip_equal(*args):
60        # FUTURE: when only Python 3.10+ is supported,
61        # drop this conditional and call the builtin zip
62        # function directly where due
63        return zip(*args, strict=True)
64
65
# The installed matplotlib version, parsed into a comparable version object
# (via packaging's ``parse``).
MPL_VERSION = parse_version(matplotlib.__version__)

# Some magic for dealing with pyparsing being included or not
# included in matplotlib (not in gentoo, yes in everything else):
# prefer the copy bundled with matplotlib and fall back to the standalone
# pyparsing package when that module is not importable.
try:
    from matplotlib.pyparsing_py3 import ParseFatalException
except ImportError:
    from pyparsing import ParseFatalException
74
75
def get_window_parameters(axis, center, width, ds):
    """Compute the window for an axis-aligned plot.

    The width and center are sanitized by the dataset's coordinate handler.
    The window is centered on the display center's in-plane coordinates.

    Returns
    -------
    tuple
        ``(bounds, center, display_center)``, where ``bounds`` is
        ``(xmin, xmax, ymin, ymax)``.
    """
    coords = ds.coordinates
    width = coords.sanitize_width(axis, width, None)
    center, display_center = coords.sanitize_center(center, axis)
    mid_x = display_center[coords.x_axis[axis]]
    mid_y = display_center[coords.y_axis[axis]]
    half_x = width[0] / 2
    half_y = width[1] / 2
    bounds = (mid_x - half_x, mid_x + half_x, mid_y - half_y, mid_y + half_y)
    return bounds, center, display_center
88
89
def get_oblique_window_parameters(normal, center, width, ds, depth=None):
    """Compute ``(bounds, center)`` for an off-axis window along *normal*.

    ``bounds`` is symmetric about zero: one ``(-w/2, +w/2)`` pair per
    sanitized width entry, with every width converted to code_length.
    There are presumably three entries when *depth* is given.

    When exactly two widths are present, the center is re-expressed in the
    cutting-plane frame: it is normalized to the domain, shifted so the
    domain center sits at 0, then rotated into the ``(perp1, perp2, normal)``
    basis returned by ``ortho_find``.
    """
    coords = ds.coordinates
    # The second value returned by sanitize_center is the one used as center.
    _, center = coords.sanitize_center(center, 4)
    width = coords.sanitize_width(normal, width, depth)

    if len(width) == 2:
        # Transforming to the cutting plane coordinate system
        center = (center - ds.domain_left_edge) / ds.domain_width - 0.5
        normal, perp1, perp2 = ortho_find(normal)
        rotation = np.transpose(np.column_stack((perp1, perp2, normal)))
        center = np.dot(rotation, center)

    widths = tuple(el.in_units("code_length") for el in width)
    bounds = tuple(
        edge for extent in widths for edge in ((-1) * extent / 2, 1 * extent / 2)
    )
    return bounds, center
105
106
def get_axes_unit(width, ds):
    r"""
    Infer the pair of axis unit names implied by a width specification.

    Returns ``("code_length", "code_length")`` when the dataset has no
    CGS-equivalent length unit. Otherwise returns a 2-tuple of unit strings
    taken from a ``(value, unit)`` pair, a pair of such pairs, or YTArray
    values. Returns ``None`` when no unit can be inferred.
    """
    if ds.no_cgs_equiv_length:
        return ("code_length",) * 2

    if not is_sequence(width):
        if isinstance(width, YTArray):
            unit = str(width.units)
            return (unit, unit)
        return None

    second = width[1]
    if isinstance(second, str):
        # (value, unit)
        return (second, second)
    if is_sequence(second):
        # ((value, unit), (value, unit))
        return (width[0][1], second[1])
    if isinstance(width[0], YTArray):
        return (str(width[0].units), str(second.units))
    return None
128
129
def validate_mesh_fields(data_source, fields):
    """Raise YTInvalidFieldType if any of *fields* cannot be pixelized.

    A field is rejected when its sampling type is ``"particle"``, unless the
    dataset defines ``_sph_ptypes`` and the field is an SPH field.
    """
    ds = data_source.ds
    # this check doesn't make sense for ytdata plot datasets, which
    # load mesh data as a particle field but nonetheless can still
    # make plots with it
    if isinstance(ds, YTSpatialPlotDataset):
        return

    invalid_fields = []
    for field in data_source._determine_fields(fields):
        finfo = ds.field_info[field]
        if finfo.sampling_type != "particle":
            continue
        # Particle-sampled fields are acceptable only as SPH fields.
        if hasattr(ds, "_sph_ptypes") and finfo.is_sph_field:
            continue
        invalid_fields.append(field)

    if invalid_fields:
        raise YTInvalidFieldType(invalid_fields)
149
150
151class PlotWindow(ImagePlotContainer):
152    r"""
153    A plotting mechanism based around the concept of a window into a
154    data source. It can have arbitrary fields, each of which will be
155    centered on the same viewpoint, but will have individual zlimits.
156
157    The data and plot are updated separately, and each can be
158    invalidated as the object is modified.
159
160    Data is handled by a FixedResolutionBuffer object.
161
162    Parameters
163    ----------
164
165    data_source :
166        :class:`yt.data_objects.selection_objects.base_objects.YTSelectionContainer2D`
167        This is the source to be pixelized, which can be a projection,
168        slice, or a cutting plane.
169    bounds : sequence of floats
170        Bounds are the min and max in the image plane that we want our
171        image to cover.  It's in the order of (xmin, xmax, ymin, ymax),
172        where the coordinates are all in the appropriate code units.
173    buff_size : sequence of ints
174        The size of the image to generate.
175    antialias : boolean
176        This can be true or false.  It determines whether or not sub-pixel
177        rendering is used during data deposition.
178    window_size : float
179        The size of the window on the longest axis (in units of inches),
180        including the margins but not the colorbar.
181    right_handed : boolean
182        Whether the implicit east vector for the image generated is set to make a right
183        handed coordinate system with a north vector and the normal vector, the
184        direction of the 'window' into the data.
185
186    """
187
    def __init__(
        self,
        data_source,
        bounds,
        buff_size=(800, 800),
        antialias=True,
        periodic=True,
        origin="center-window",
        oblique=False,
        right_handed=True,
        window_size=8.0,
        fields=None,
        fontsize=18,
        aspect=None,
        setup=False,
    ):
        """Initialize the plot window; parameters are described in the class docstring.

        NOTE(review): ``setup`` is accepted but never read in this
        constructor -- confirm whether callers still rely on it.
        """
        self.center = None
        self._periodic = periodic
        self.oblique = oblique
        self._right_handed = right_handed
        # Per-field (equivalency, equivalency_kwargs) reused when units are
        # re-applied to a regenerated buffer; (None, {}) means "no equivalency".
        self._equivalencies = defaultdict(lambda: (None, {}))
        self.buff_size = buff_size
        self.antialias = antialias
        self._axes_unit_names = None
        # Matplotlib transform/projection: set below only for axis-aligned
        # plots whose data source has a center, or later by set_mpl_projection.
        self._transform = None
        self._projection = None

        self.aspect = aspect
        # Fields the FRB excludes, plus the data source's key fields, are not
        # plotted as images. Requested ones are kept in override_fields so that
        # _recreate_frb still loads them into the buffer.
        skip = list(FixedResolutionBuffer._exclude_fields) + data_source._key_fields

        fields = list(iter_fields(fields))
        self.override_fields = list(set(fields).intersection(set(skip)))
        self.fields = [f for f in fields if f not in skip]
        # presumably sets data_source/ds and the per-field config dicts used
        # below (_field_transform, _log_config, _units_config) -- confirm
        super().__init__(data_source, window_size, fontsize)

        self._set_window(bounds)  # this automatically updates the data and plot
        self.origin = origin
        # For axis-aligned plots, recenter the window on the in-plane
        # coordinates of the data source's display center, and pick up the
        # coordinate handler's data transform/projection for this axis.
        if self.data_source.center is not None and not oblique:
            ax = self.data_source.axis
            xax = self.ds.coordinates.x_axis[ax]
            yax = self.ds.coordinates.y_axis[ax]
            center, display_center = self.ds.coordinates.sanitize_center(
                self.data_source.center, ax
            )
            center = [display_center[xax], display_center[yax]]
            self.set_center(center)

            axname = self.ds.coordinates.axis_name[ax]
            transform = self.ds.coordinates.data_transform[axname]
            projection = self.ds.coordinates.data_projection[axname]
            self._projection = get_mpl_transform(projection)
            self._transform = get_mpl_transform(transform)

        # Initial per-field scaling: log or linear depending on the field's
        # take_log flag. Any non-None log setting in _log_config then
        # overrides it through set_log.
        for field in self.data_source._determine_fields(self.fields):
            finfo = self.data_source.ds._get_field_info(*field)
            if finfo.take_log:
                self._field_transform[field] = log_transform
            else:
                self._field_transform[field] = linear_transform

            log, linthresh = self._log_config[field]
            if log is not None:
                self.set_log(field, log, linthresh=linthresh)

            # Access the dictionary to force the key to be created
            self._units_config[field]

        self.setup_callbacks()
        self._setup_plots()
257
258    def __iter__(self):
259        for ds in self.ts:
260            mylog.warning("Switching to %s", ds)
261            self._switch_ds(ds)
262            yield self
263
264    def piter(self, *args, **kwargs):
265        for ds in self.ts.piter(*args, **kwargs):
266            self._switch_ds(ds)
267            yield self
268
269    _frb = None
270
271    def frb():
272        doc = "The frb property."
273
274        def fget(self):
275            if self._frb is None or not self._data_valid:
276                self._recreate_frb()
277            return self._frb
278
279        def fset(self, value):
280            self._frb = value
281            self._data_valid = True
282
283        def fdel(self):
284            del self._frb
285            self._frb = None
286            self._data_valid = False
287
288        return locals()
289
290    frb = property(**frb())
291
292    def _recreate_frb(self):
293        old_fields = None
294        # If we are regenerating an frb, we want to know what fields we had before
295        if self._frb is not None:
296            old_fields = list(self._frb.keys())
297            old_units = [str(self._frb[of].units) for of in old_fields]
298
299        # Set the bounds
300        if hasattr(self, "zlim"):
301            bounds = self.xlim + self.ylim + self.zlim
302        else:
303            bounds = self.xlim + self.ylim
304
305        # Generate the FRB
306        self.frb = self._frb_generator(
307            self.data_source,
308            bounds,
309            self.buff_size,
310            self.antialias,
311            periodic=self._periodic,
312        )
313
314        # At this point the frb has the valid bounds, size, aliasing, etc.
315        if old_fields is None:
316            self._frb._get_data_source_fields()
317
318            # New frb, apply default units (if any)
319            for field, field_unit in self._units_config.items():
320                if field_unit is None:
321                    continue
322
323                field_unit = Unit(field_unit, registry=self.ds.unit_registry)
324                is_projected = getattr(self, "projected", False)
325                if is_projected:
326                    # Obtain config
327                    path_length_units = Unit(
328                        ytcfg.get_most_specific(
329                            "plot", *field, "path_length_units", fallback="cm"
330                        ),
331                        registry=self.ds.unit_registry,
332                    )
333                    units = field_unit * path_length_units
334                else:
335                    units = field_unit
336                try:
337                    self.frb[field].convert_to_units(units)
338                except UnitConversionError:
339                    msg = (
340                        "Could not apply default units from configuration.\n"
341                        "Tried converting projected field %s from %s to %s, retaining units %s:\n"
342                        "\tgot units for field: %s"
343                    )
344                    args = [
345                        field,
346                        self.frb[field].units,
347                        units,
348                        field_unit,
349                        units,
350                    ]
351                    if is_projected:
352                        msg += "\n\tgot units for integration length: %s"
353                        args += [path_length_units]
354
355                    msg += "\nCheck your configuration file."
356
357                    mylog.error(msg, *args)
358        else:
359            # Restore the old fields
360            for key, units in zip(old_fields, old_units):
361                self._frb[key]
362                equiv = self._equivalencies[key]
363                if equiv[0] is None:
364                    self._frb[key].convert_to_units(units)
365                else:
366                    self.frb.set_unit(key, units, equiv[0], equiv[1])
367
368        # Restore the override fields
369        for key in self.override_fields:
370            self._frb[key]
371
372    @property
373    def width(self):
374        Wx = self.xlim[1] - self.xlim[0]
375        Wy = self.ylim[1] - self.ylim[0]
376        return (Wx, Wy)
377
378    @property
379    def bounds(self):
380        return self.xlim + self.ylim
381
382    @invalidate_data
383    def zoom(self, factor):
384        r"""This zooms the window by *factor* > 0.
385        - zoom out with *factor* < 1
386        - zoom in with *factor* > 1
387
388        Parameters
389        ----------
390        factor : float
391            multiplier for the current width
392
393        """
394        if factor <= 0:
395            raise ValueError("Only positive zooming factors are meaningful.")
396        Wx, Wy = self.width
397        centerx = self.xlim[0] + Wx * 0.5
398        centery = self.ylim[0] + Wy * 0.5
399        nWx, nWy = Wx / factor, Wy / factor
400        self.xlim = (centerx - nWx * 0.5, centerx + nWx * 0.5)
401        self.ylim = (centery - nWy * 0.5, centery + nWy * 0.5)
402        return self
403
404    @invalidate_data
405    def pan(self, deltas):
406        r"""Pan the image by specifying absolute code unit coordinate deltas.
407
408        Parameters
409        ----------
410        deltas : Two-element sequence of floats, quantities, or (float, unit)
411                 tuples.
412
413            (delta_x, delta_y).  If a unit is not supplied the unit is assumed
414            to be code_length.
415
416        """
417        if len(deltas) != 2:
418            raise RuntimeError(
419                f"The pan function accepts a two-element sequence.\nReceived {deltas}."
420            )
421        if isinstance(deltas[0], Number) and isinstance(deltas[1], Number):
422            deltas = (
423                self.ds.quan(deltas[0], "code_length"),
424                self.ds.quan(deltas[1], "code_length"),
425            )
426        elif isinstance(deltas[0], tuple) and isinstance(deltas[1], tuple):
427            deltas = (
428                self.ds.quan(deltas[0][0], deltas[0][1]),
429                self.ds.quan(deltas[1][0], deltas[1][1]),
430            )
431        elif isinstance(deltas[0], YTQuantity) and isinstance(deltas[1], YTQuantity):
432            pass
433        else:
434            raise RuntimeError(
435                "The arguments of the pan function must be a sequence of floats,\n"
436                "quantities, or (float, unit) tuples. Received %s." % (deltas,)
437            )
438        self.xlim = (self.xlim[0] + deltas[0], self.xlim[1] + deltas[0])
439        self.ylim = (self.ylim[0] + deltas[1], self.ylim[1] + deltas[1])
440        return self
441
442    @invalidate_data
443    def pan_rel(self, deltas):
444        r"""Pan the image by specifying relative deltas, to the FOV.
445
446        Parameters
447        ----------
448        deltas : sequence of floats
449            (delta_x, delta_y) in *relative* code unit coordinates
450
451        """
452        Wx, Wy = self.width
453        self.xlim = (self.xlim[0] + Wx * deltas[0], self.xlim[1] + Wx * deltas[0])
454        self.ylim = (self.ylim[0] + Wy * deltas[1], self.ylim[1] + Wy * deltas[1])
455        return self
456
457    @invalidate_plot
458    def set_unit(self, field, new_unit, equivalency=None, equivalency_kwargs=None):
459        """Sets a new unit for the requested field
460
461        parameters
462        ----------
463        field : string or field tuple
464           The name of the field that is to be changed.
465
466        new_unit : string or Unit object
467           The name of the new unit.
468
469        equivalency : string, optional
470           If set, the equivalency to use to convert the current units to
471           the new requested unit. If None, the unit conversion will be done
472           without an equivalency
473
474        equivalency_kwargs : string, optional
475           Keyword arguments to be passed to the equivalency. Only used if
476           ``equivalency`` is set.
477        """
478        if equivalency_kwargs is None:
479            equivalency_kwargs = {}
480        field = self.data_source._determine_fields(field)[0]
481        for f, u in zip_equal(iter_fields(field), always_iterable(new_unit)):
482            self.frb.set_unit(f, u, equivalency, equivalency_kwargs)
483            self._equivalencies[f] = (equivalency, equivalency_kwargs)
484        return self
485
486    @invalidate_plot
487    def set_origin(self, origin):
488        """Set the plot origin.
489
490        Parameters
491        ----------
492        origin : string or length 1, 2, or 3 sequence.
493           The location of the origin of the plot coordinate system. This
494           is typically represented by a '-' separated string or a tuple of
495           strings. In the first index the y-location is given by 'lower',
496           'upper', or 'center'. The second index is the x-location, given as
497           'left', 'right', or 'center'. Finally, whether the origin is
498           applied in 'domain' space, plot 'window' space or 'native'
499           simulation coordinate system is given. For example, both
500           'upper-right-domain' and ['upper', 'right', 'domain'] place the
501           origin in the upper right hand corner of domain space. If x or y
502           are not given, a value is inferred. For instance, 'left-domain'
503           corresponds to the lower-left hand corner of the simulation domain,
504           'center-domain' corresponds to the center of the simulation domain,
505           or 'center-window' for the center of the plot window. In the event
506           that none of these options place the origin in a desired location,
507           a sequence of tuples and a string specifying the
508           coordinate space can be given. If plain numeric types are input,
509           units of `code_length` are assumed. Further examples:
510
511        ===============================================  ===============================
512        format                                           example
513        ===============================================  ===============================
514        '{space}'                                        'domain'
515        '{xloc}-{space}'                                 'left-window'
516        '{yloc}-{space}'                                 'upper-domain'
517        '{yloc}-{xloc}-{space}'                          'lower-right-window'
518        ('{space}',)                                     ('window',)
519        ('{xloc}', '{space}')                            ('right', 'domain')
520        ('{yloc}', '{space}')                            ('lower', 'window')
521        ('{yloc}', '{xloc}', '{space}')                  ('lower', 'right', 'window')
522        ((yloc, '{unit}'), (xloc, '{unit}'), '{space}')  ((0, 'm'), (.4, 'm'), 'window')
523        (xloc, yloc, '{space}')                          (0.23, 0.5, 'domain')
524        ===============================================  ===============================
525        """
526        self.origin = origin
527        return self
528
529    @invalidate_plot
530    @invalidate_figure
531    def set_mpl_projection(self, mpl_proj):
532        r"""
533        Set the matplotlib projection type with a cartopy transform function
534
535        Given a string or a tuple argument, this will project the data onto
536        the plot axes with the chosen transform function.
537
538        Assumes that the underlying data has a PlateCarree transform type.
539
540        To annotate the plot with coastlines or other annotations,
541        `_setup_plots()` will need to be called after this function
542        to make the axes available for annotation.
543
544        Parameters
545        ----------
546
547        mpl_proj : string or tuple
548           if passed as a string, mpl_proj is the specified projection type,
549           if passed as a tuple, then tuple will take the form of
550           ``("ProjectionType", (args))`` or ``("ProjectionType", (args), {kwargs})``
551           Valid projection type options include:
552           'PlateCarree', 'LambertConformal', 'LabmbertCylindrical',
553           'Mercator', 'Miller', 'Mollweide', 'Orthographic',
554           'Robinson', 'Stereographic', 'TransverseMercator',
555           'InterruptedGoodeHomolosine', 'RotatedPole', 'OGSB',
556           'EuroPP', 'Geostationary', 'Gnomonic', 'NorthPolarStereo',
557           'OSNI', 'SouthPolarStereo', 'AlbersEqualArea',
558           'AzimuthalEquidistant', 'Sinusoidal', 'UTM',
559           'NearsidePerspective', 'LambertAzimuthalEqualArea'
560
561        Examples
562        --------
563
564        This will create a Mollweide projection using Mollweide default values
565        and annotate it with coastlines
566
567        >>> import yt
568        >>> ds = yt.load("")
569        >>> p = yt.SlicePlot(ds, "altitude", "AIRDENS")
570        >>> p.set_mpl_projection("AIRDENS", "Mollweide")
571        >>> p._setup_plots()
572        >>> p.plots["AIRDENS"].axes.coastlines()
573        >>> p.show()
574
575        This will move the PlateCarree central longitude to 90 degrees and
576        annotate with coastlines.
577
578        >>> import yt
579        >>> ds = yt.load("")
580        >>> p = yt.SlicePlot(ds, "altitude", "AIRDENS")
581        >>> p.set_mpl_projection(
582        ...     "AIRDENS", ("PlateCarree", (), {"central_longitude": 90, "globe": None})
583        ... )
584        >>> p._setup_plots()
585        >>> p.plots["AIRDENS"].axes.set_global()
586        >>> p.plots["AIRDENS"].axes.coastlines()
587        >>> p.show()
588
589
590        This will create a RoatatedPole projection with the unrotated pole
591        position at 37.5 degrees latitude and 177.5 degrees longitude by
592        passing them in as args.
593
594
595        >>> import yt
596        >>> ds = yt.load("")
597        >>> p = yt.SlicePlot(ds, "altitude", "AIRDENS")
598        >>> p.set_mpl_projection("RotatedPole", (177.5, 37.5))
599        >>> p._setup_plots()
600        >>> p.plots["AIRDENS"].axes.set_global()
601        >>> p.plots["AIRDENS"].axes.coastlines()
602        >>> p.show()
603
604        This will create a RoatatedPole projection with the unrotated pole
605        position at 37.5 degrees latitude and 177.5 degrees longitude by
606        passing them in as kwargs.
607
608        >>> import yt
609        >>> ds = yt.load("")
610        >>> p = yt.SlicePlot(ds, "altitude", "AIRDENS")
611        >>> p.set_mpl_projection(
612        ...     ("RotatedPole", (), {"pole_latitude": 37.5, "pole_longitude": 177.5})
613        ... )
614        >>> p._setup_plots()
615        >>> p.plots["AIRDENS"].axes.set_global()
616        >>> p.plots["AIRDENS"].axes.coastlines()
617        >>> p.show()
618
619        """
620
621        self._projection = get_mpl_transform(mpl_proj)
622        axname = self.ds.coordinates.axis_name[self.data_source.axis]
623        transform = self.ds.coordinates.data_transform[axname]
624        self._transform = get_mpl_transform(transform)
625        return self
626
627    @invalidate_data
628    def _set_window(self, bounds):
629        """Set the bounds of the plot window.
630        This is normally only called internally, see set_width.
631
632
633        Parameters
634        ----------
635
636        bounds : a four element sequence of floats
637            The x and y bounds, in the format (x0, x1, y0, y1)
638
639        """
640        if self.center is not None:
641            dx = bounds[1] - bounds[0]
642            dy = bounds[3] - bounds[2]
643            self.xlim = (self.center[0] - dx / 2.0, self.center[0] + dx / 2.0)
644            self.ylim = (self.center[1] - dy / 2.0, self.center[1] + dy / 2.0)
645        else:
646            self.xlim = tuple(bounds[0:2])
647            self.ylim = tuple(bounds[2:4])
648            if len(bounds) == 6:
649                self.zlim = tuple(bounds[4:6])
650        mylog.info("xlim = %f %f", self.xlim[0], self.xlim[1])
651        mylog.info("ylim = %f %f", self.ylim[0], self.ylim[1])
652        if hasattr(self, "zlim"):
653            mylog.info("zlim = %f %f", self.zlim[0], self.zlim[1])
654
655    @invalidate_data
656    def set_width(self, width, unit=None):
657        """set the width of the plot window
658
659        parameters
660        ----------
661        width : float, array of floats, (float, unit) tuple, or tuple of
662                (float, unit) tuples.
663
664             Width can have four different formats to support windows with
665             variable x and y widths.  They are:
666
667             ==================================     =======================
668             format                                 example
669             ==================================     =======================
670             (float, string)                        (10,'kpc')
671             ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
672             float                                  0.2
673             (float, float)                         (0.2, 0.3)
674             ==================================     =======================
675
676             For example, (10, 'kpc') requests a plot window that is 10
677             kiloparsecs wide in the x and y directions,
678             ((10,'kpc'),(15,'kpc')) requests a window that is 10 kiloparsecs
679             wide along the x axis and 15 kiloparsecs wide along the y axis.
680             In the other two examples, code units are assumed, for example
681             (0.2, 0.3) requests a plot that has an x width of 0.2 and a y
682             width of 0.3 in code units.  If units are provided the resulting
683             plot axis labels will use the supplied units.
684        unit : str
685             the unit the width has been specified in. If width is a tuple, this
686             argument is ignored. Defaults to code units.
687        """
688        if isinstance(width, Number):
689            if unit is None:
690                width = (width, "code_length")
691            else:
692                width = (width, fix_unitary(unit))
693
694        axes_unit = get_axes_unit(width, self.ds)
695
696        width = self.ds.coordinates.sanitize_width(self.frb.axis, width, None)
697
698        centerx = (self.xlim[1] + self.xlim[0]) / 2.0
699        centery = (self.ylim[1] + self.ylim[0]) / 2.0
700
701        self.xlim = (centerx - width[0] / 2, centerx + width[0] / 2)
702        self.ylim = (centery - width[1] / 2, centery + width[1] / 2)
703
704        if hasattr(self, "zlim"):
705            centerz = (self.zlim[1] + self.zlim[0]) / 2.0
706            mw = self.ds.arr(width).max()
707            self.zlim = (centerz - mw / 2.0, centerz + mw / 2.0)
708
709        self.set_axes_unit(axes_unit)
710
711        return self
712
713    @invalidate_data
714    def set_center(self, new_center, unit="code_length"):
715        """Sets a new center for the plot window
716
717        parameters
718        ----------
719        new_center : two element sequence of floats
720            The coordinates of the new center of the image in the
721            coordinate system defined by the plot axes. If the unit
722            keyword is not specified, the coordinates are assumed to
723            be in code units.
724
725        unit : string
726            The name of the unit new_center is given in.  If new_center is a
727            YTArray or tuple of YTQuantities, this keyword is ignored.
728
729        """
730        error = RuntimeError(
731            "\n"
732            "new_center must be a two-element list or tuple of floats \n"
733            "corresponding to a coordinate in the plot relative to \n"
734            "the plot coordinate system.\n"
735        )
736        if new_center is None:
737            self.center = None
738        elif is_sequence(new_center):
739            if len(new_center) != 2:
740                raise error
741            for el in new_center:
742                if not isinstance(el, Number) and not isinstance(el, YTQuantity):
743                    raise error
744            if isinstance(new_center[0], Number):
745                new_center = [self.ds.quan(c, unit) for c in new_center]
746            self.center = new_center
747        else:
748            raise error
749        self._set_window(self.bounds)
750        return self
751
752    @invalidate_data
753    def set_antialias(self, aa):
754        """Turn antialiasing on or off.
755
756        parameters
757        ----------
758        aa : boolean
759        """
760        self.antialias = aa
761
762    @invalidate_data
763    def set_buff_size(self, size):
764        """Sets a new buffer size for the fixed resolution buffer
765
766        parameters
767        ----------
768        size : int or two element sequence of ints
769            The number of data elements in the buffer on the x and y axes.
770            If a scalar is provided,  then the buffer is assumed to be square.
771        """
772        if is_sequence(size):
773            self.buff_size = size
774        else:
775            self.buff_size = (size, size)
776        return self
777
778    def set_window_size(self, size):
779        """This calls set_figure_size to adjust the size of the plot window."""
780        from yt._maintenance.deprecation import issue_deprecation_warning
781
782        issue_deprecation_warning(
783            "`PlotWindow.set_window_size` is a deprecated alias "
784            "for `PlotWindow.set_figure_size`.",
785            removal="4.1.0",
786        )
787        self.set_figure_size(size)
788        return self
789
790    @invalidate_plot
791    def set_axes_unit(self, unit_name):
792        r"""Set the unit for display on the x and y axes of the image.
793
794        Parameters
795        ----------
796        unit_name : string or two element tuple of strings
797            A unit, available for conversion in the dataset, that the
798            image extents will be displayed in.  If set to None, any previous
799            units will be reset.  If the unit is None, the default is chosen.
800            If unit_name is '1', 'u', or 'unitary', it will not display the
801            units, and only show the axes name. If unit_name is a tuple, the
802            first element is assumed to be the unit for the x axis and the
803            second element the unit for the y axis.
804
805        Raises
806        ------
807        YTUnitNotRecognized
808            If the unit is not known, this will be raised.
809
810        Examples
811        --------
812
813        >>> from yt import load
814        >>> ds = load("IsolatedGalaxy/galaxy0030/galaxy0030")
815        >>> p = ProjectionPlot(ds, "y", "Density")
816        >>> p.set_axes_unit("kpc")
817
818        """
819        # blind except because it could be in conversion_factors or units
820        if unit_name is not None:
821            if isinstance(unit_name, str):
822                unit_name = (unit_name, unit_name)
823            for un in unit_name:
824                try:
825                    self.ds.length_unit.in_units(un)
826                except (UnitConversionError, UnitParseError) as e:
827                    raise YTUnitNotRecognized(un) from e
828        self._axes_unit_names = unit_name
829        return self
830
831    @invalidate_plot
832    def toggle_right_handed(self):
833        self._right_handed = not self._right_handed
834
835    def to_fits_data(self, fields=None, other_keys=None, length_unit=None, **kwargs):
836        r"""Export the fields in this PlotWindow instance
837        to a FITSImageData instance.
838
839        This will export a set of FITS images of either the fields specified
840        or all the fields already in the object.
841
842        Parameters
843        ----------
844        fields : list of strings
845            These fields will be pixelized and output. If "None", the keys of
846            the FRB will be used.
847        other_keys : dictionary, optional
848            A set of header keys and values to write into the FITS header.
849        length_unit : string, optional
850            the length units that the coordinates are written in. The default
851            is to use the default length unit of the dataset.
852        """
853        return self.frb.to_fits_data(
854            fields=fields, other_keys=other_keys, length_unit=length_unit, **kwargs
855        )
856
857
class PWViewerMPL(PlotWindow):
    """Viewer using matplotlib as a backend via the WindowPlotMPL.

    One ``WindowPlotMPL`` per plotted field is stored in ``self.plots``.
    Subclasses either define ``_frb_generator`` and ``_plot_type`` as class
    attributes or pass them as the ``frb_generator`` and ``plot_type``
    keyword arguments.
    """

    _current_field = None
    _frb_generator = None
    _plot_type = None
    _data_valid = False

    def __init__(self, *args, **kwargs):
        # Consume the viewer-specific keywords before delegating the rest to
        # PlotWindow. A required keyword that is missing raises KeyError.
        if self._frb_generator is None:
            self._frb_generator = kwargs.pop("frb_generator")
        if self._plot_type is None:
            self._plot_type = kwargs.pop("plot_type")
        # Optional single color used to paint every positive pixel instead of
        # mapping values through a colormap (see _setup_plots).
        self._splat_color = kwargs.pop("splat_color", None)
        PlotWindow.__init__(self, *args, **kwargs)

    def _setup_origin(self):
        """Return the ``(x, y)`` position of the plot origin as quantities.

        ``self.origin`` is either a '-' separated string or a sequence of
        1, 2 or 3 elements. Keyword forms are normalized to the 3-tuple
        ``(yloc, xloc, space)``; explicit coordinates ``(x, y, space)`` are
        used directly, either as plain numbers (code_length) or as
        ``(value, unit)`` pairs.

        Raises
        ------
        RuntimeError
            If a location keyword is not recognized or the origin lies outside
            the bounds of the selected coordinate space.
        """
        origin = self.origin
        axis_index = self.data_source.axis
        xc = None
        yc = None

        if isinstance(origin, str):
            origin = tuple(origin.split("-"))[:3]
        else:
            # Lists are documented as valid input; normalize so that the tuple
            # concatenations below do not raise a TypeError on them.
            origin = tuple(origin)
        if len(origin) == 1:
            origin = ("lower", "left") + origin
        elif len(origin) == 2 and origin[0] in {"left", "right", "center"}:
            o0map = {"left": "lower", "right": "upper", "center": "center"}
            origin = (o0map[origin[0]],) + origin
        elif len(origin) == 2 and origin[0] in {"lower", "upper", "center"}:
            origin = (origin[0], "center", origin[-1])
        elif len(origin) == 3 and isinstance(origin[0], Number):
            # Plain numbers (including numpy scalars) are in code units.
            xc = self.ds.quan(origin[0], "code_length")
            yc = self.ds.quan(origin[1], "code_length")
        elif len(origin) == 3 and isinstance(origin[0], (tuple, list)):
            # (value, unit) pairs: each coordinate uses its own unit.
            xc = self.ds.quan(origin[0][0], origin[0][1])
            yc = self.ds.quan(origin[1][0], origin[1][1])

        space = origin[-1]
        assert space in ["window", "domain", "native"]

        if space == "window":
            xllim, xrlim = self.xlim
            yllim, yrlim = self.ylim
        elif space == "domain":
            xax = self.ds.coordinates.x_axis[axis_index]
            yax = self.ds.coordinates.y_axis[axis_index]
            xllim = self.ds.domain_left_edge[xax]
            xrlim = self.ds.domain_right_edge[xax]
            yllim = self.ds.domain_left_edge[yax]
            yrlim = self.ds.domain_right_edge[yax]
        elif space == "native":
            return (self.ds.quan(0.0, "code_length"), self.ds.quan(0.0, "code_length"))
        else:
            # Only reachable when assertions are disabled (python -O).
            mylog.warning("origin = %s", origin)
            msg = (
                'origin keyword "{}" not recognized, must declare "window", '
                '"domain" or "native" as the last term in origin.'
            ).format(self.origin)
            raise RuntimeError(msg)

        if xc is None and yc is None:
            if origin[0] == "lower":
                yc = yllim
            elif origin[0] == "upper":
                yc = yrlim
            elif origin[0] == "center":
                yc = (yllim + yrlim) / 2.0
            else:
                mylog.warning("origin = %s", origin)
                msg = (
                    'origin keyword "{0}" not recognized, must declare "lower" '
                    '"upper" or "center" as the first term in origin.'
                )
                msg = msg.format(self.origin)
                raise RuntimeError(msg)

            if origin[1] == "left":
                xc = xllim
            elif origin[1] == "right":
                xc = xrlim
            elif origin[1] == "center":
                xc = (xllim + xrlim) / 2.0
            else:
                mylog.warning("origin = %s", origin)
                msg = (
                    'origin keyword "{0}" not recognized, must declare "left" '
                    '"right" or "center" as the second term in origin.'
                )
                msg = msg.format(self.origin)
                raise RuntimeError(msg)

        x_in_bounds = xc >= xllim and xc <= xrlim
        y_in_bounds = yc >= yllim and yc <= yrlim

        if not x_in_bounds and not y_in_bounds:
            msg = "origin inputs not in bounds of specified coordinate system domain."
            raise RuntimeError(msg)

        return xc, yc

    def _setup_plots(self):
        """Build (or rebuild) the matplotlib plot of every field.

        Returns immediately if the plots are still valid. Otherwise the fixed
        resolution buffer is regenerated when the data is invalid, one
        ``WindowPlotMPL`` is created per field (reusing an existing figure and
        axes when available), labels and minor ticks are configured, and the
        registered callbacks are run.

        Raises
        ------
        YTCannotParseUnitDisplayName
            If a colorbar label is not valid matplotlib mathtext.
        """
        from matplotlib.mathtext import MathTextParser

        if self._plot_valid:
            return
        if not self._data_valid:
            self._recreate_frb()
            self._data_valid = True
        self._colorbar_valid = True
        # Only used to validate colorbar labels; one instance serves all fields.
        parser = MathTextParser("Agg")
        for f in list(set(self.data_source._determine_fields(self.fields))):
            axis_index = self.data_source.axis

            xc, yc = self._setup_origin()
            if self.ds.unit_system._code_flag or self.ds.no_cgs_equiv_length:
                # this should happen only if the dataset was initialized with
                # argument unit_system="code" or if it's set to have no CGS
                # equivalent.  This only needs to happen here in the specific
                # case that we're doing a computationally intense operation
                # like using cartopy, but it prevents crashes in that case.
                (unit_x, unit_y) = ("code_length", "code_length")
            elif self._axes_unit_names is None:
                unit = self.ds.get_smallest_appropriate_unit(
                    self.xlim[1] - self.xlim[0]
                )
                (unit_x, unit_y) = (unit, unit)
            else:
                (unit_x, unit_y) = self._axes_unit_names

            # For some plots we may set aspect by hand, such as for spectral cube data.
            # This will likely be replaced at some point by the coordinate handler
            # setting plot aspect.
            if self.aspect is None:
                self.aspect = float(
                    (self.ds.quan(1.0, unit_y) / self.ds.quan(1.0, unit_x)).in_cgs()
                )
            extentx = [(self.xlim[i] - xc).in_units(unit_x) for i in (0, 1)]
            extenty = [(self.ylim[i] - yc).in_units(unit_y) for i in (0, 1)]

            extent = extentx + extenty

            if f in self.plots.keys():
                zlim = (self.plots[f].zmin, self.plots[f].zmax)
            else:
                zlim = (None, None)

            image = self.frb[f]
            if self._field_transform[f] == log_transform:
                # Fall back to symlog or linear scaling when the data cannot
                # be shown on a log scale (unless limits were set explicitly).
                msg = None
                use_symlog = False
                if zlim != (None, None):
                    pass
                elif np.nanmax(image) == np.nanmin(image):
                    msg = f"Plotting {f}: All values = {np.nanmax(image)}"
                elif np.nanmax(image) <= 0:
                    msg = (
                        f"Plotting {f}: All negative values. Max = {np.nanmax(image)}."
                    )
                    use_symlog = True
                elif not np.any(np.isfinite(image)):
                    msg = f"Plotting {f}: All values = NaN."
                elif np.nanmax(image) > 0.0 and np.nanmin(image) < 0:
                    msg = (
                        f"Plotting {f}: Both positive and negative values. "
                        f"Min = {np.nanmin(image)}, Max = {np.nanmax(image)}."
                    )
                    use_symlog = True
                elif np.nanmax(image) > 0.0 and np.nanmin(image) == 0:
                    # normally, a LogNorm scaling would still be OK here because
                    # LogNorm will mask 0 values when calculating vmin. But
                    # due to a bug in matplotlib's imshow, if the data range
                    # spans many orders of magnitude while containing zero points
                    # vmin can get rescaled to 0, resulting in an error when the image
                    # gets drawn. So here we switch to symlog to avoid that until
                    # a fix is in -- see PR #3161 and linked issue.
                    cutoff_sigdigs = 15
                    if (
                        np.log10(np.nanmax(image[np.isfinite(image)]))
                        - np.log10(np.nanmin(image[image > 0]))
                        > cutoff_sigdigs
                    ):
                        msg = f"Plotting {f}: Wide range and zeros."
                        use_symlog = True
                if msg is not None:
                    mylog.warning(msg)
                    if use_symlog:
                        mylog.warning("Switching to symlog colorbar scaling.")
                        self._field_transform[f] = symlog_transform
                        self._field_transform[f].func = None
                    else:
                        mylog.warning("Switching to linear colorbar scaling.")
                        self._field_transform[f] = linear_transform

            font_size = self._font_properties.get_size()

            # Reuse the existing figure/axes and display toggles of a plot
            # that was already built for this field.
            fig = None
            axes = None
            cax = None
            draw_colorbar = True
            draw_axes = True
            draw_frame = draw_axes
            if f in self.plots:
                draw_colorbar = self.plots[f]._draw_colorbar
                draw_axes = self.plots[f]._draw_axes
                draw_frame = self.plots[f]._draw_frame
                if self.plots[f].figure is not None:
                    fig = self.plots[f].figure
                    axes = self.plots[f].axes
                    cax = self.plots[f].cax

            # This is for splatting particle positions with a single
            # color instead of a colormap
            if self._splat_color is not None:
                # RGBA image: fully transparent (alpha 0) everywhere except
                # the pixels with positive values, which get the splat color.
                ia = np.zeros((image.shape[0], image.shape[1], 4))
                locs = image > 0.0
                to_rgba = matplotlib.colors.colorConverter.to_rgba
                ia[locs] = to_rgba(self._splat_color)
                ia = ImageArray(ia)
            else:
                ia = image
            self.plots[f] = WindowPlotMPL(
                ia,
                self._field_transform[f].name,
                self._field_transform[f].func,
                self._colormap_config[f],
                extent,
                zlim,
                self.figure_size,
                font_size,
                self.aspect,
                fig,
                axes,
                cax,
                self._projection,
                self._transform,
            )

            if not self._right_handed:
                ax = self.plots[f].axes
                ax.invert_xaxis()

            axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)

            if self.oblique:
                labels = [
                    r"$\rm{Image\ x" + axes_unit_labels[0] + "}$",
                    r"$\rm{Image\ y" + axes_unit_labels[1] + "}$",
                ]
            else:
                coordinates = self.ds.coordinates
                axis_names = coordinates.image_axis_name[axis_index]
                xax = coordinates.x_axis[axis_index]
                yax = coordinates.y_axis[axis_index]

                if hasattr(coordinates, "axis_default_unit_name"):
                    axes_unit_labels = [
                        coordinates.axis_default_unit_name[xax],
                        coordinates.axis_default_unit_name[yax],
                    ]
                labels = [
                    r"$\rm{" + axis_names[0] + axes_unit_labels[0] + r"}$",
                    r"$\rm{" + axis_names[1] + axes_unit_labels[1] + r"}$",
                ]

                if hasattr(coordinates, "axis_field"):
                    if xax in coordinates.axis_field:
                        xmin, xmax = coordinates.axis_field[xax](
                            0, self.xlim, self.ylim
                        )
                    else:
                        xmin, xmax = (float(x) for x in extentx)
                    if yax in coordinates.axis_field:
                        ymin, ymax = coordinates.axis_field[yax](
                            1, self.xlim, self.ylim
                        )
                    else:
                        ymin, ymax = (float(y) for y in extenty)
                    self.plots[f].image.set_extent((xmin, xmax, ymin, ymax))
                    self.plots[f].axes.set_aspect("auto")

            x_label, y_label, colorbar_label = self._get_axes_labels(f)

            if x_label is not None:
                labels[0] = x_label
            if y_label is not None:
                labels[1] = y_label

            self.plots[f].axes.set_xlabel(labels[0])
            self.plots[f].axes.set_ylabel(labels[1])

            color = self._background_color[f]

            self.plots[f].axes.set_facecolor(color)

            # Determine the units of the data
            units = Unit(image.units, registry=self.ds.unit_registry)
            units = units.latex_representation()

            if colorbar_label is None:
                colorbar_label = image.info["label"]
                if hasattr(self, "projected"):
                    colorbar_label = "$\\rm{Projected }$ %s" % colorbar_label
                if units:
                    colorbar_label += r"$\ \ \left(" + units + r"\right)$"

            try:
                parser.parse(colorbar_label)
            except ValueError as err:
                # matplotlib reports mathtext syntax errors as ValueError.
                raise YTCannotParseUnitDisplayName(
                    f, colorbar_label, str(err)
                ) from err

            self.plots[f].cb.set_label(colorbar_label)

            # x-y axes minorticks
            if f not in self._minorticks:
                self._minorticks[f] = True
            if self._minorticks[f]:
                self.plots[f].axes.minorticks_on()
            else:
                self.plots[f].axes.minorticks_off()

            # colorbar minorticks
            if f not in self._cbar_minorticks:
                self._cbar_minorticks[f] = True

            if self._cbar_minorticks[f]:
                vmin = np.float64(self.plots[f].cb.norm.vmin)
                vmax = np.float64(self.plots[f].cb.norm.vmax)

                if self._field_transform[f] == linear_transform:
                    self.plots[f].cax.minorticks_on()

                elif self._field_transform[f] == symlog_transform:
                    flinthresh = 10 ** np.floor(
                        np.log10(self.plots[f].cb.norm.linthresh)
                    )
                    mticks = self.plots[f].image.norm(
                        get_symlog_minorticks(flinthresh, vmin, vmax)
                    )
                    self.plots[f].cax.yaxis.set_ticks(mticks, minor=True)

                elif self._field_transform[f] == log_transform:
                    if MPL_VERSION >= parse_version("3.0.0"):
                        self.plots[f].cax.minorticks_on()
                        self.plots[f].cax.xaxis.set_visible(False)
                    else:
                        mticks = self.plots[f].image.norm(
                            get_log_minorticks(vmin, vmax)
                        )
                        self.plots[f].cax.yaxis.set_ticks(mticks, minor=True)

                else:
                    mylog.error(
                        "Unable to draw cbar minorticks for field "
                        "%s with transform %s ",
                        f,
                        self._field_transform[f],
                    )
                    self._cbar_minorticks[f] = False

            if not self._cbar_minorticks[f]:
                self.plots[f].cax.minorticks_off()

            if not draw_axes:
                self.plots[f]._toggle_axes(draw_axes, draw_frame)

            if not draw_colorbar:
                self.plots[f]._toggle_colorbar(draw_colorbar)

        self._set_font_properties()
        self.run_callbacks()
        self._plot_valid = True

    def setup_callbacks(self):
        """Attach an ``annotate_<name>`` method for each applicable callback.

        Every entry of ``callback_registry`` not excluded for this plot type
        becomes an instance attribute that records the callback through
        ``apply_callback`` and invalidates the plot through
        ``invalidate_plot``.
        """
        # Callbacks that do not apply to some plot types. The list only
        # depends on the plot type, so it is built once.
        ignored = ["PlotCallback"]
        if self._plot_type.startswith("OffAxis"):
            ignored += [
                "ParticleCallback",
                "ClumpContourCallback",
                "GridBoundaryCallback",
            ]
        if self._plot_type == "OffAxisProjection":
            ignored += [
                "VelocityCallback",
                "MagFieldCallback",
                "QuiverCallback",
                "CuttingQuiverCallback",
                "StreamlineCallback",
            ]
        if self._plot_type == "Particle":
            ignored += [
                "HopCirclesCallback",
                "HopParticleCallback",
                "ClumpContourCallback",
                "GridBoundaryCallback",
                "VelocityCallback",
                "MagFieldCallback",
                "QuiverCallback",
                "CuttingQuiverCallback",
                "StreamlineCallback",
                "ContourCallback",
            ]

        def make_method(CallbackMaker):
            # Build the method in its own scope so that it is bound to this
            # specific callback class rather than to a loop variable.
            @wraps(CallbackMaker)
            def method(*args, **kwargs):
                # We need to also do it here as "invalidate_plot"
                # and "apply_callback" require the functions'
                # __name__ in order to work properly
                @wraps(CallbackMaker)
                def cb(self, *a, **kwa):
                    # We construct the callback method
                    # skipping self
                    return CallbackMaker(*a, **kwa)

                # Create callback
                cb = invalidate_plot(apply_callback(cb))

                return cb(self, *args, **kwargs)

            return method

        for key in callback_registry:
            if key in ignored:
                continue
            cbname = callback_registry[key]._type_name
            self.__dict__["annotate_" + cbname] = make_method(callback_registry[key])

    def annotate_clear(self, index=None):
        """
        Clear callbacks from the plot.  If index is not set, clear all
        callbacks.  If index is set, clear that index (ie 0 is the first one
        created, 1 is the 2nd one created, -1 is the last one created, etc.)

        .. note::

            Deprecated in favor of `clear_annotations`.

        See Also
        --------
        :py:meth:`yt.visualization.plot_window.PWViewerMPL.clear_annotations`
        """
        issue_deprecation_warning(
            "`annotate_clear` has been deprecated "
            "in favor of `clear_annotations`. Using `clear_annotations`.",
            since="4.0.0",
            removal="4.1.0",
        )
        self.clear_annotations(index=index)

    @invalidate_plot
    def clear_annotations(self, index=None):
        """
        Clear callbacks from the plot.  If index is not set, clear all
        callbacks.  If index is set, clear that index (ie 0 is the first one
        created, 1 is the 2nd one created, -1 is the last one created, etc.)

        Returns the plot object itself, allowing chained calls.
        """
        if index is None:
            self._callbacks = []
        else:
            del self._callbacks[index]
        self.setup_callbacks()
        return self

    def list_annotations(self):
        """
        List the current callbacks for the plot, along with their index.  This
        index can be used with `clear_annotations` to remove a callback from the
        current plot.
        """
        for i, cb in enumerate(self._callbacks):
            print(i, cb)

    def run_callbacks(self):
        """Apply every recorded callback to the plot of every field.

        Fields that callbacks add to the fixed resolution buffer are removed
        again once the field's callbacks have run.

        Raises
        ------
        YTDataTypeUnsupported
            Re-raised unchanged from a callback.
        YTPlotCallbackError
            Wrapping any other exception raised by a callback.
        """
        for f in self.fields:
            # Snapshot the keys: frb.keys() may be a live dict view, in which
            # case entries added by callbacks would look pre-existing and
            # never be cleaned up.
            keys = set(self.frb.keys())
            for name, (args, kwargs) in self._callbacks:
                cbw = CallbackWrapper(
                    self,
                    self.plots[f],
                    self.frb,
                    f,
                    self._font_properties,
                    self._font_color,
                )
                CallbackMaker = callback_registry[name]
                callback = CallbackMaker(*args[1:], **kwargs)
                try:
                    callback(cbw)
                except YTDataTypeUnsupported as e:
                    raise e
                except Exception as e:
                    raise YTPlotCallbackError(callback._type_name) from e
            # Iterate over a copy because entries are deleted in the loop.
            for key in list(self.frb.keys()):
                if key not in keys:
                    del self.frb[key]

    def export_to_mpl_figure(
        self,
        nrows_ncols,
        axes_pad=1.0,
        label_mode="L",
        cbar_location="right",
        cbar_size="5%",
        cbar_mode="each",
        cbar_pad="0%",
    ):
        r"""
        Creates a matplotlib figure object with the specified axes arrangement,
        nrows_ncols, and maps the underlying figures to the matplotlib axes.
        Note that all of these parameters are fed directly to the matplotlib ImageGrid
        class to create the new figure layout.

        Parameters
        ----------

        nrows_ncols : tuple
           the number of rows and columns of the axis grid (e.g., nrows_ncols=(2,2,))
        axes_pad : float
           padding between axes in inches
        label_mode : one of "L", "1", "all"
           arrangement of axes that are labeled
        cbar_location : one of "left", "right", "bottom", "top"
           where to place the colorbar
        cbar_size : string (percentage)
           scaling of the colorbar (e.g., "5%")
        cbar_mode : one of "each", "single", "edge", None
           how to represent the colorbar
        cbar_pad : string (percentage)
           padding between the axis and colorbar (e.g. "5%")

        Returns
        -------

        The return is a matplotlib figure object.

        Raises
        ------

        IndexError
            If the grid has fewer axes than there are fields.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load_sample("IsolatedGalaxy")
        >>> fields = ["density", "velocity_x", "velocity_y", "velocity_magnitude"]
        >>> p = yt.SlicePlot(ds, "z", fields)
        >>> p.set_log("velocity_x", False)
        >>> p.set_log("velocity_y", False)
        >>> fig = p.export_to_mpl_figure((2, 2))
        >>> fig.tight_layout()
        >>> fig.savefig("test.png")

        """

        fig = plt.figure()
        grid = ImageGrid(
            fig,
            111,
            nrows_ncols=nrows_ncols,
            axes_pad=axes_pad,
            label_mode=label_mode,
            cbar_location=cbar_location,
            cbar_size=cbar_size,
            cbar_mode=cbar_mode,
            cbar_pad=cbar_pad,
        )

        fields = self.fields
        if len(fields) > len(grid):
            raise IndexError("not enough axes for the number of fields")

        for i, f in enumerate(fields):
            plot = self.plots[f]
            plot.figure = fig
            plot.axes = grid[i].axes
            plot.cax = grid.cbar_axes[i]

        # Force a redraw: _setup_plots returns early when the plots are
        # already valid, which would leave the new figure's axes empty.
        self._plot_valid = False
        self._setup_plots()

        return fig
1444
1445
1446class AxisAlignedSlicePlot(PWViewerMPL):
1447    r"""Creates a slice plot from a dataset
1448
1449    Given a ds object, an axis to slice along, and a field name
1450    string, this will return a PWViewerMPL object containing
1451    the plot.
1452
1453    The plot can be updated using one of the many helper functions
1454    defined in PlotWindow.
1455
1456    Parameters
1457    ----------
1458    ds : `Dataset`
1459         This is the dataset object corresponding to the
1460         simulation output to be plotted.
1461    axis : int or one of 'x', 'y', 'z'
1462         An int corresponding to the axis to slice along (0=x, 1=y, 2=z)
1463         or the axis name itself
1464    fields : string
1465         The name of the field(s) to be plotted.
1466    center : A sequence of floats, a string, or a tuple.
1467         The coordinate of the center of the image. If set to 'c', 'center' or
1468         left blank, the plot is centered on the middle of the domain. If set to
1469         'max' or 'm', the center will be located at the maximum of the
1470         ('gas', 'density') field. Centering on the max or min of a specific
1471         field is supported by providing a tuple such as ("min","temperature") or
1472         ("max","dark_matter_density"). Units can be specified by passing in *center*
1473         as a tuple containing a coordinate and string unit name or by passing
1474         in a YTArray. If a list or unitless array is supplied, code units are
1475         assumed.
1476    width : tuple or a float.
1477         Width can have four different formats to support windows with variable
1478         x and y widths.  They are:
1479
1480         ==================================     =======================
1481         format                                 example
1482         ==================================     =======================
1483         (float, string)                        (10,'kpc')
1484         ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
1485         float                                  0.2
1486         (float, float)                         (0.2, 0.3)
1487         ==================================     =======================
1488
1489         For example, (10, 'kpc') requests a plot window that is 10 kiloparsecs
1490         wide in the x and y directions, ((10,'kpc'),(15,'kpc')) requests a
1491         window that is 10 kiloparsecs wide along the x axis and 15
1492         kiloparsecs wide along the y axis.  In the other two examples, code
1493         units are assumed, for example (0.2, 0.3) requests a plot that has an
1494         x width of 0.2 and a y width of 0.3 in code units.  If units are
1495         provided the resulting plot axis labels will use the supplied units.
1496    origin : string or length 1, 2, or 3 sequence.
1497         The location of the origin of the plot coordinate system. This
1498         is typically represented by a '-' separated string or a tuple of
1499         strings. In the first index the y-location is given by 'lower',
1500         'upper', or 'center'. The second index is the x-location, given as
1501         'left', 'right', or 'center'. Finally, whether the origin is
1502         applied in 'domain' space, plot 'window' space or 'native'
1503         simulation coordinate system is given. For example, both
1504         'upper-right-domain' and ['upper', 'right', 'domain'] place the
1505         origin in the upper right hand corner of domain space. If x or y
1506         are not given, a value is inferred. For instance, 'left-domain'
1507         corresponds to the lower-left hand corner of the simulation domain,
1508         'center-domain' corresponds to the center of the simulation domain,
1509         or 'center-window' for the center of the plot window. In the event
1510         that none of these options place the origin in a desired location,
1511         a sequence of tuples and a string specifying the
1512         coordinate space can be given. If plain numeric types are input,
1513         units of `code_length` are assumed. Further examples:
1514
1515         =============================================== ===============================
1516         format                                          example
1517         =============================================== ===============================
1518         '{space}'                                       'domain'
1519         '{xloc}-{space}'                                'left-window'
1520         '{yloc}-{space}'                                'upper-domain'
1521         '{yloc}-{xloc}-{space}'                         'lower-right-window'
1522         ('{space}',)                                    ('window',)
1523         ('{xloc}', '{space}')                           ('right', 'domain')
1524         ('{yloc}', '{space}')                           ('lower', 'window')
1525         ('{yloc}', '{xloc}', '{space}')                 ('lower', 'right', 'window')
1526         ((yloc, '{unit}'), (xloc, '{unit}'), '{space}') ((0, 'm'), (.4, 'm'), 'window')
1527         (xloc, yloc, '{space}')                         (0.23, 0.5, 'domain')
1528         =============================================== ===============================
1529    axes_unit : string
1530         The name of the unit for the tick labels on the x and y axes.
1531         Defaults to None, which automatically picks an appropriate unit.
1532         If axes_unit is '1', 'u', or 'unitary', it will not display the
1533         units, and only show the axes name.
1534    right_handed : boolean
1535         Whether the implicit east vector for the image generated is set to make a right
1536         handed coordinate system with a normal vector, the direction of the
1537         'window' into the data.
1538    fontsize : integer
1539         The size of the fonts for the axis, colorbar, and tick labels.
    field_parameters : dictionary
         A dictionary of field parameters that can be accessed by derived
         fields.
    data_source : YTSelectionContainer object
         Object to be used for data selection. Defaults to ds.all_data(), a
         region covering the full domain.
    buff_size : length 2 sequence
         Size of the buffer to use for the image, i.e. the number of resolution elements
         used.  Effectively sets a resolution limit to the image if buff_size is
         smaller than the finest gridding.
1550
1551    Examples
1552    --------
1553
1554    This will save an image in the file 'sliceplot_Density.png'
1555
1556    >>> from yt import load
1557    >>> ds = load("IsolatedGalaxy/galaxy0030/galaxy0030")
1558    >>> p = SlicePlot(ds, 2, "density", "c", (20, "kpc"))
1559    >>> p.save("sliceplot")
1560
1561    """
1562    _plot_type = "Slice"
1563    _frb_generator = FixedResolutionBuffer
1564
1565    def __init__(
1566        self,
1567        ds,
1568        axis,
1569        fields,
1570        center="c",
1571        width=None,
1572        axes_unit=None,
1573        origin="center-window",
1574        right_handed=True,
1575        fontsize=18,
1576        field_parameters=None,
1577        window_size=8.0,
1578        aspect=None,
1579        data_source=None,
1580        buff_size=(800, 800),
1581    ):
1582        # this will handle time series data and controllers
1583        axis = fix_axis(axis, ds)
1584        (bounds, center, display_center) = get_window_parameters(
1585            axis, center, width, ds
1586        )
1587        if field_parameters is None:
1588            field_parameters = {}
1589
1590        if ds.geometry in (
1591            "spherical",
1592            "cylindrical",
1593            "geographic",
1594            "internal_geographic",
1595        ):
1596            mylog.info("Setting origin='native' for %s geometry.", ds.geometry)
1597            origin = "native"
1598
1599        if isinstance(ds, YTSpatialPlotDataset):
1600            slc = ds.all_data()
1601            slc.axis = axis
1602            if slc.axis != ds.parameters["axis"]:
1603                raise RuntimeError(f"Original slice axis is {ds.parameters['axis']}.")
1604        else:
1605            slc = ds.slice(
1606                axis,
1607                center[axis],
1608                field_parameters=field_parameters,
1609                center=center,
1610                data_source=data_source,
1611            )
1612            slc.get_data(fields)
1613        validate_mesh_fields(slc, fields)
1614        PWViewerMPL.__init__(
1615            self,
1616            slc,
1617            bounds,
1618            origin=origin,
1619            fontsize=fontsize,
1620            fields=fields,
1621            window_size=window_size,
1622            aspect=aspect,
1623            right_handed=right_handed,
1624            buff_size=buff_size,
1625        )
1626        if axes_unit is None:
1627            axes_unit = get_axes_unit(width, ds)
1628        self.set_axes_unit(axes_unit)
1629
1630
class ProjectionPlot(PWViewerMPL):
    r"""Creates a projection plot from a dataset

    Given a ds object, an axis to project along, and a field name
    string, this will return a PWViewerMPL object containing
    the plot.

    The plot can be updated using one of the many helper functions
    defined in PlotWindow.

    Parameters
    ----------
    ds : `Dataset`
         This is the dataset object corresponding to the
         simulation output to be plotted.
    axis : int or one of 'x', 'y', 'z'
         An int corresponding to the axis to project along (0=x, 1=y, 2=z)
         or the axis name itself
    fields : string
         The name of the field(s) to be plotted.
    center : A sequence of floats, a string, or a tuple.
         The coordinate of the center of the image. If set to 'c', 'center' or
         left blank, the plot is centered on the middle of the domain. If set to
         'max' or 'm', the center will be located at the maximum of the
         ('gas', 'density') field. Centering on the max or min of a specific
         field is supported by providing a tuple such as ("min","temperature") or
         ("max","dark_matter_density"). Units can be specified by passing in *center*
         as a tuple containing a coordinate and string unit name or by passing
         in a YTArray. If a list or unitless array is supplied, code units are
         assumed.
    width : tuple or a float.
         Width can have four different formats to support windows with variable
         x and y widths.  They are:

         ==================================     =======================
         format                                 example
         ==================================     =======================
         (float, string)                        (10,'kpc')
         ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
         float                                  0.2
         (float, float)                         (0.2, 0.3)
         ==================================     =======================

         For example, (10, 'kpc') requests a plot window that is 10 kiloparsecs
         wide in the x and y directions, ((10,'kpc'),(15,'kpc')) requests a
         window that is 10 kiloparsecs wide along the x axis and 15
         kiloparsecs wide along the y axis.  In the other two examples, code
         units are assumed, for example (0.2, 0.3) requests a plot that has an
         x width of 0.2 and a y width of 0.3 in code units.  If units are
         provided the resulting plot axis labels will use the supplied units.
    axes_unit : string
         The name of the unit for the tick labels on the x and y axes.
         Defaults to None, which automatically picks an appropriate unit.
         If axes_unit is '1', 'u', or 'unitary', it will not display the
         units, and only show the axes name.
    origin : string or length 1, 2, or 3 sequence.
         The location of the origin of the plot coordinate system. This
         is typically represented by a '-' separated string or a tuple of
         strings. In the first index the y-location is given by 'lower',
         'upper', or 'center'. The second index is the x-location, given as
         'left', 'right', or 'center'. Finally, whether the origin is
         applied in 'domain' space, plot 'window' space or 'native'
         simulation coordinate system is given. For example, both
         'upper-right-domain' and ['upper', 'right', 'domain'] place the
         origin in the upper right hand corner of domain space. If x or y
         are not given, a value is inferred. For instance, 'left-domain'
         corresponds to the lower-left hand corner of the simulation domain,
         'center-domain' corresponds to the center of the simulation domain,
         or 'center-window' for the center of the plot window. In the event
         that none of these options place the origin in a desired location,
         a sequence of tuples and a string specifying the
         coordinate space can be given. If plain numeric types are input,
         units of `code_length` are assumed. For spherical, cylindrical,
         geographic and internal_geographic datasets the origin is always
         set to 'native'. Further examples:

         =============================================== ===============================
         format                                          example
         =============================================== ===============================
         '{space}'                                       'domain'
         '{xloc}-{space}'                                'left-window'
         '{yloc}-{space}'                                'upper-domain'
         '{yloc}-{xloc}-{space}'                         'lower-right-window'
         ('{space}',)                                    ('window',)
         ('{xloc}', '{space}')                           ('right', 'domain')
         ('{yloc}', '{space}')                           ('lower', 'window')
         ('{yloc}', '{xloc}', '{space}')                 ('lower', 'right', 'window')
         ((yloc, '{unit}'), (xloc, '{unit}'), '{space}') ((0, 'm'), (.4, 'm'), 'window')
         (xloc, yloc, '{space}')                         (0.23, 0.5, 'domain')
         =============================================== ===============================

    right_handed : boolean
         Whether the implicit east vector for the image generated is set to make a right
         handed coordinate system with the direction of the
         'window' into the data.
    data_source : YTSelectionContainer Object
         Object to be used for data selection. Defaults to ds.all_data(), a
         region covering the full domain.
    weight_field : string
         The name of the weighting field.  Set to None for no weight.
    max_level : int
         The maximum level to project to.
    fontsize : integer
         The size of the fonts for the axis, colorbar, and tick labels.
    method : string
         The method of projection.  Valid methods are:

         "integrate" with no weight_field specified : integrate the requested
         field along the line of sight.

         "integrate" with a weight_field specified : weight the requested
         field by the weighting field and integrate along the line of sight.

         "mip" : pick out the maximum value of the field in the line of sight.

         "sum" : This method is the same as integrate, except that it does not
         multiply by a path length when performing the integration, and is
         just a straight summation of the field along the given axis. WARNING:
         This should only be used for uniform resolution grid datasets, as other
         datasets may result in unphysical images.
    proj_style : string
         The method of projection--same as method keyword.  Deprecated as of
         version 3.0.2 and scheduled for removal in 4.1.0.  Please use method
         instead.  When given, it overrides ``method``.
    window_size : float
         The size of the window in inches. Set to 8 by default.
    aspect : float
         The aspect ratio of the plot.  Set to None for 1.
    field_parameters : dictionary
         A dictionary of field parameters that can be accessed by derived
         fields.
    buff_size : length 2 sequence
         Size of the buffer to use for the image, i.e. the number of resolution elements
         used.  Effectively sets a resolution limit to the image if buff_size is
         smaller than the finest gridding.

    Examples
    --------

    Create a projection plot with a width of 20 kiloparsecs centered on the
    center of the simulation box:

    >>> from yt import load
    >>> ds = load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> p = ProjectionPlot(ds, "z", ("gas", "density"), width=(20, "kpc"))

    """
    _plot_type = "Projection"
    _frb_generator = FixedResolutionBuffer

    def __init__(
        self,
        ds,
        axis,
        fields,
        center="c",
        width=None,
        axes_unit=None,
        weight_field=None,
        max_level=None,
        origin="center-window",
        right_handed=True,
        fontsize=18,
        field_parameters=None,
        data_source=None,
        method="integrate",
        proj_style=None,
        window_size=8.0,
        buff_size=(800, 800),
        aspect=None,
    ):
        axis = fix_axis(axis, ds)
        # Any user-supplied origin is overridden for these geometries.
        if ds.geometry in (
            "spherical",
            "cylindrical",
            "geographic",
            "internal_geographic",
        ):
            mylog.info("Setting origin='native' for %s geometry.", ds.geometry)
            origin = "native"
        if proj_style is not None:
            issue_deprecation_warning(
                "`proj_style` parameter is deprecated, use `method` instead.",
                removal="4.1.0",
            )
            method = proj_style
        # If a non-weighted integral projection, assure field-label reflects that
        if weight_field is None and method == "integrate":
            self.projected = True
        (bounds, center, display_center) = get_window_parameters(
            axis, center, width, ds
        )
        if field_parameters is None:
            field_parameters = {}

        # We don't use the plot's data source for validation like in the other
        # plotting classes to avoid an exception
        test_data_source = ds.all_data()
        validate_mesh_fields(test_data_source, fields)

        if isinstance(ds, YTSpatialPlotDataset):
            # The dataset already holds projected data: reuse it, requiring the
            # same axis it was originally projected along.
            proj = ds.all_data()
            proj.axis = axis
            if proj.axis != ds.parameters["axis"]:
                raise RuntimeError(
                    f"Original projection axis is {ds.parameters['axis']}."
                )
            if weight_field is not None:
                proj.weight_field = proj._determine_fields(weight_field)[0]
            else:
                proj.weight_field = weight_field
            proj.center = center
        else:
            proj = ds.proj(
                fields,
                axis,
                weight_field=weight_field,
                center=center,
                data_source=data_source,
                field_parameters=field_parameters,
                method=method,
                max_level=max_level,
            )
        PWViewerMPL.__init__(
            self,
            proj,
            bounds,
            fields=fields,
            origin=origin,
            right_handed=right_handed,
            fontsize=fontsize,
            window_size=window_size,
            aspect=aspect,
            buff_size=buff_size,
        )
        if axes_unit is None:
            axes_unit = get_axes_unit(width, ds)
        self.set_axes_unit(axes_unit)
1869
1870
class OffAxisSlicePlot(PWViewerMPL):
    r"""Creates an off axis slice plot from a dataset

    Given a ds object, a normal vector defining a slicing plane, and
    a field name string, this will return a PWViewerMPL object
    containing the plot.

    The plot can be updated using one of the many helper functions
    defined in PlotWindow.

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
         This is the dataset object corresponding to the
         simulation output to be plotted.
    normal : a sequence of floats
         The vector normal to the slicing plane.
    fields : string
         The name of the field(s) to be plotted.
    center : A sequence of floats, a string, or a tuple.
         The coordinate of the center of the image. If set to 'c', 'center' or
         left blank, the plot is centered on the middle of the domain. If set to
         'max' or 'm', the center will be located at the maximum of the
         ('gas', 'density') field. Centering on the max or min of a specific
         field is supported by providing a tuple such as ("min","temperature") or
         ("max","dark_matter_density"). Units can be specified by passing in *center*
         as a tuple containing a coordinate and string unit name or by passing
         in a YTArray. If a list or unitless array is supplied, code units are
         assumed.
    width : tuple or a float.
         Width can have four different formats to support windows with variable
         x and y widths.  They are:

         ==================================     =======================
         format                                 example
         ==================================     =======================
         (float, string)                        (10,'kpc')
         ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
         float                                  0.2
         (float, float)                         (0.2, 0.3)
         ==================================     =======================

         For example, (10, 'kpc') requests a plot window that is 10 kiloparsecs
         wide in the x and y directions, ((10,'kpc'),(15,'kpc')) requests a
         window that is 10 kiloparsecs wide along the x axis and 15
         kiloparsecs wide along the y axis.  In the other two examples, code
         units are assumed, for example (0.2, 0.3) requests a plot that has an
         x width of 0.2 and a y width of 0.3 in code units.  If units are
         provided the resulting plot axis labels will use the supplied units.
    axes_unit : string
         The name of the unit for the tick labels on the x and y axes.
         Defaults to None, which automatically picks an appropriate unit.
         If axes_unit is '1', 'u', or 'unitary', it will not display the
         units, and only show the axes name.
    north_vector : a sequence of floats
         A vector defining the 'up' direction in the plot.  This
         option sets the orientation of the slicing plane.  If not
         set, an arbitrary grid-aligned north-vector is chosen.
    right_handed : boolean
         Whether the implicit east vector for the image generated is set to make a right
         handed coordinate system with the north vector and the normal, the direction of
         the 'window' into the data.
    fontsize : integer
         The size of the fonts for the axis, colorbar, and tick labels.
    field_parameters : dictionary
         A dictionary of field parameters that can be accessed by derived
         fields.
    data_source : YTSelectionContainer Object
         Object to be used for data selection.  Defaults to ds.all_data(), a
         region covering the full domain.
    buff_size : length 2 sequence
         Size of the buffer to use for the image, i.e. the number of resolution elements
         used.  Effectively sets a resolution limit to the image if buff_size is
         smaller than the finest gridding.
    """

    _plot_type = "OffAxisSlice"
    _frb_generator = FixedResolutionBuffer

    def __init__(
        self,
        ds,
        normal,
        fields,
        center="c",
        width=None,
        axes_unit=None,
        north_vector=None,
        right_handed=True,
        fontsize=18,
        field_parameters=None,
        data_source=None,
        buff_size=(800, 800),
    ):
        (bounds, center_rot) = get_oblique_window_parameters(normal, center, width, ds)
        if field_parameters is None:
            field_parameters = {}

        if isinstance(ds, YTSpatialPlotDataset):
            # The dataset already holds off-axis slice data; mark it as an
            # oblique object (axis index 4) and restore its stored inverse
            # rotation matrix.
            cutting = ds.all_data()
            cutting.axis = 4
            cutting._inv_mat = ds.parameters["_inv_mat"]
        else:
            cutting = ds.cutting(
                normal,
                center,
                north_vector=north_vector,
                field_parameters=field_parameters,
                data_source=data_source,
            )
            cutting.get_data(fields)
        validate_mesh_fields(cutting, fields)
        # Hard-coding the origin keyword since the other two options
        # aren't well-defined for off-axis data objects
        PWViewerMPL.__init__(
            self,
            cutting,
            bounds,
            fields=fields,
            origin="center-window",
            periodic=False,
            right_handed=right_handed,
            oblique=True,
            fontsize=fontsize,
            buff_size=buff_size,
        )
        if axes_unit is None:
            axes_unit = get_axes_unit(width, ds)
        self.set_axes_unit(axes_unit)
2000
2001
class OffAxisProjectionDummyDataSource:
    """Parameter holder that stands in for a projection object in off-axis plots.

    No projection is computed here. The constructor only does three things:

    * records the projection settings as attributes;
    * resolves field names against the selection data source, which is
      ``ds.all_data()`` unless ``data_source`` is given;
    * builds an ``Orientation`` from the normal and north vectors.
    """

    _type_name = "proj"
    _key_fields = []

    def __init__(
        self,
        center,
        ds,
        normal_vector,
        width,
        fields,
        interpolated,
        weight=None,
        volume=None,
        no_ghost=False,
        le=None,
        re=None,
        north_vector=None,
        method="integrate",
        data_source=None,
    ):
        self.ds = ds
        self.center = center
        self.normal_vector = normal_vector
        self.width = width
        # Oblique (off-axis) data objects are always tagged with axis index 4.
        self.axis = 4

        # Field names are resolved against the selection data source.
        self.dd = data_source if data_source is not None else ds.all_data()
        self.fields = self.dd._determine_fields(fields)
        self.interpolated = interpolated
        # Only the first resolved entry is kept for the weight field.
        self.weight_field = (
            self.dd._determine_fields(weight)[0] if weight is not None else None
        )

        self.volume = volume
        self.no_ghost = no_ghost
        self.le, self.re = le, re
        self.north_vector = north_vector
        self.method = method
        self.orienter = Orientation(normal_vector, north_vector=north_vector)

    def _determine_fields(self, *args):
        """Resolve field names by delegating to the selection data source."""
        resolve = self.dd._determine_fields
        return resolve(*args)
2048
2049
class OffAxisProjectionPlot(PWViewerMPL):
    r"""Creates an off axis projection plot from a dataset

    Given a ds object, a normal vector to project along, and
    a field name string, this will return a PWViewerMPL object
    containing the plot.

    The plot can be updated using one of the many helper functions
    defined in PlotWindow.

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
         This is the dataset object corresponding to the
         simulation output to be plotted.
    normal : a sequence of floats
         The vector normal to the image plane; the projection is taken
         along this direction.
    fields : string
         The name of the field(s) to be plotted.
    center : A sequence of floats, a string, or a tuple.
         The coordinate of the center of the image. If set to 'c', 'center' or
         left blank, the plot is centered on the middle of the domain. If set to
         'max' or 'm', the center will be located at the maximum of the
         ('gas', 'density') field. Centering on the max or min of a specific
         field is supported by providing a tuple such as ("min","temperature") or
         ("max","dark_matter_density"). Units can be specified by passing in *center*
         as a tuple containing a coordinate and string unit name or by passing
         in a YTArray. If a list or unitless array is supplied, code units are
         assumed.
    width : tuple or a float.
         Width can have four different formats to support windows with variable
         x and y widths.  They are:

         ==================================     =======================
         format                                 example
         ==================================     =======================
         (float, string)                        (10,'kpc')
         ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
         float                                  0.2
         (float, float)                         (0.2, 0.3)
         ==================================     =======================

         For example, (10, 'kpc') requests a plot window that is 10 kiloparsecs
         wide in the x and y directions, ((10,'kpc'),(15,'kpc')) requests a
         window that is 10 kiloparsecs wide along the x axis and 15
         kiloparsecs wide along the y axis.  In the other two examples, code
         units are assumed, for example (0.2, 0.3) requests a plot that has an
         x width of 0.2 and a y width of 0.3 in code units.  If units are
         provided the resulting plot axis labels will use the supplied units.
    depth : A tuple or a float
         A tuple containing the depth to project through and the string
         key of the unit: (depth, 'unit').  If set to a float, code units
         are assumed.  Defaults to (1, "1").
    weight_field : string
         The name of the weighting field.  Set to None for no weight.
    max_level : int
         The maximum level to project to.  If given, it is set as the
         ``max_level`` attribute of the selection data source
         (``ds.all_data()`` or ``data_source``).  A user-supplied
         ``data_source`` is therefore modified in place.
    axes_unit : string
         The name of the unit for the tick labels on the x and y axes.
         Defaults to None, which automatically picks an appropriate unit.
         If axes_unit is '1', 'u', or 'unitary', it will not display the
         units, and only show the axes name.
    north_vector : a sequence of floats
         A vector defining the 'up' direction in the plot.  This
         option sets the orientation of the image plane.  If not
         set, an arbitrary grid-aligned north-vector is chosen.
    right_handed : boolean
         Whether the implicit east vector for the image generated is set to make a right
         handed coordinate system with the north vector and the normal, the direction of
         the 'window' into the data.
    volume : optional
         Stored unchanged on the off-axis projection data source for use
         when the image buffer is generated.  Defaults to None.
    no_ghost : boolean
         Stored unchanged on the off-axis projection data source.
         Defaults to False.
    le, re : optional
         Stored unchanged on the off-axis projection data source; presumably
         left/right edges bounding the projected region (TODO confirm).
         Default to None.
    interpolated : boolean
         Stored unchanged on the off-axis projection data source.
         Defaults to False.
    fontsize : integer
         The size of the fonts for the axis, colorbar, and tick labels.
    method : string
         The method of projection.  Valid methods are:

         "integrate" with no weight_field specified : integrate the requested
         field along the line of sight.

         "integrate" with a weight_field specified : weight the requested
         field by the weighting field and integrate along the line of sight.

         "sum" : This method is the same as integrate, except that it does not
         multiply by a path length when performing the integration, and is
         just a straight summation of the field along the given axis. WARNING:
         This should only be used for uniform resolution grid datasets, as other
         datasets may result in unphysical images.
    data_source : YTSelectionContainer object
         Object to be used for data selection. Defaults to ds.all_data(), a
         region covering the full domain.
    buff_size : length 2 sequence
         Size of the buffer to use for the image, i.e. the number of resolution elements
         used.  Effectively sets a resolution limit to the image if buff_size is
         smaller than the finest gridding.
    """
    _plot_type = "OffAxisProjection"
    _frb_generator = OffAxisProjectionFixedResolutionBuffer

    def __init__(
        self,
        ds,
        normal,
        fields,
        center="c",
        width=None,
        depth=(1, "1"),
        axes_unit=None,
        weight_field=None,
        max_level=None,
        north_vector=None,
        right_handed=True,
        volume=None,
        no_ghost=False,
        le=None,
        re=None,
        interpolated=False,
        fontsize=18,
        method="integrate",
        data_source=None,
        buff_size=(800, 800),
    ):
        (bounds, center_rot) = get_oblique_window_parameters(
            normal, center, width, ds, depth=depth
        )
        # Normalize fields into a fresh list (iter_fields accepts a single
        # field or a collection).
        fields = list(iter_fields(fields))
        # Extent of the projection box along each of its three axes.
        oap_width = ds.arr(
            (bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4])
        )
        OffAxisProj = OffAxisProjectionDummyDataSource(
            center_rot,
            ds,
            normal,
            oap_width,
            fields,
            interpolated,
            weight=weight_field,
            volume=volume,
            no_ghost=no_ghost,
            le=le,
            re=re,
            north_vector=north_vector,
            method=method,
            data_source=data_source,
        )

        validate_mesh_fields(OffAxisProj, fields)

        if max_level is not None:
            # NOTE: OffAxisProj.dd is the caller's data_source when one was
            # given, so this mutates that object in place.
            OffAxisProj.dd.max_level = max_level

        # If a non-weighted, integral projection, assure field label
        # reflects that
        if weight_field is None and OffAxisProj.method == "integrate":
            self.projected = True

        # Hard-coding the origin keyword since the other two options
        # aren't well-defined for off-axis data objects
        PWViewerMPL.__init__(
            self,
            OffAxisProj,
            bounds,
            fields=fields,
            origin="center-window",
            periodic=False,
            oblique=True,
            right_handed=right_handed,
            fontsize=fontsize,
            buff_size=buff_size,
        )
        if axes_unit is None:
            axes_unit = get_axes_unit(width, ds)
        self.set_axes_unit(axes_unit)
2221
2222
class WindowPlotMPL(ImagePlotMPL):
    """A container for a single PlotWindow matplotlib figure and axes"""

    def __init__(
        self,
        data,
        cbname,
        cblinthresh,
        cmap,
        extent,
        zlim,
        figure_size,
        fontsize,
        aspect,
        figure,
        axes,
        cax,
        mpl_proj,
        mpl_transform,
    ):
        """Compute the figure layout, create the axes and draw the image.

        Parameters
        ----------
        data
            Image data, handed to ``_init_image``.
        cbname : str
            Colorbar scaling name, handed to ``_init_image``. When it is
            ``"linear"`` the colorbar tick labels are additionally switched
            to scientific notation.
        cblinthresh
            Handed to ``_init_image`` unchanged.
        cmap
            Colormap, handed to ``_init_image`` unchanged.
        extent
            Presumably ``(xmin, xmax, ymin, ymax)``. Entries 0-1 and 2-3 are
            differenced and divided to get the data aspect ratio, and the
            ratio is converted with ``.in_cgs()``, so unit-carrying values
            are required.
        zlim
            Handed to the ``ImagePlotMPL`` constructor unchanged.
        figure_size : float or sequence of floats
            Figure size. When a sequence is given, only its first element is
            used to size the colorbar.
        fontsize : float
            Font size. It is divided by 18 to scale the text and padding
            sizes used for the layout.
        aspect
            Stored as ``self._unit_aspect`` and handed to ``_init_image``.
        figure, axes, cax
            Existing matplotlib objects (or None), handed to the
            ``ImagePlotMPL`` constructor.
        mpl_proj
            Passed as ``projection=`` when axes are created in
            ``_create_axes``.
        mpl_transform
            Stored as ``self._transform``.
        """
        from matplotlib.ticker import ScalarFormatter

        def _scientific_formatter():
            # Build a fresh formatter for each Axis. set_major_formatter
            # binds the formatter to that Axis, and ScalarFormatter reads
            # the bound Axis' view limits when choosing the offset and
            # exponent. Sharing one instance between x and y would compute
            # the x-axis labels from the y-axis limits.
            formatter = ScalarFormatter(useMathText=True)
            formatter.set_scientific(True)
            formatter.set_powerlimits((-2, 3))
            return formatter

        # Set everything up before _get_best_layout / the base constructor
        # run, since they (presumably) read these attributes.
        self._draw_colorbar = True
        self._draw_axes = True
        self._draw_frame = True
        self._fontsize = fontsize
        self._figure_size = figure_size
        self._projection = mpl_proj
        self._transform = mpl_transform

        # Compute layout
        fontscale = float(fontsize) / 18.0
        if fontscale < 1.0:
            # shrink small fonts' layout more gently than linearly
            fontscale = np.sqrt(fontscale)

        if is_sequence(figure_size):
            fsize = figure_size[0]
        else:
            fsize = figure_size
        self._cb_size = 0.0375 * fsize
        self._ax_text_size = [1.2 * fontscale, 0.9 * fontscale]
        self._top_buff_size = 0.30 * fontscale
        self._aspect = ((extent[1] - extent[0]) / (extent[3] - extent[2])).in_cgs()
        self._unit_aspect = aspect

        size, axrect, caxrect = self._get_best_layout()

        super().__init__(size, axrect, caxrect, zlim, figure, axes, cax)

        self._init_image(data, cbname, cblinthresh, cmap, extent, aspect)

        self.image.axes.xaxis.set_major_formatter(_scientific_formatter())
        self.image.axes.yaxis.set_major_formatter(_scientific_formatter())
        if cbname == "linear":
            self.cb.formatter.set_scientific(True)
            try:
                self.cb.formatter.set_useMathText(True)
            except AttributeError:
                # set_useMathText is missing from older matplotlib releases
                pass
            self.cb.formatter.set_powerlimits((-2, 3))
            self.cb.update_ticks()

    def _create_axes(self, axrect):
        # Create the image axes, honouring the (possibly cartopy) projection
        # stored at construction time.
        self.axes = self.figure.add_axes(axrect, projection=self._projection)
2294
2295
def SlicePlot(ds, normal=None, fields=None, axis=None, *args, **kwargs):
    r"""
    A factory function for
    :class:`yt.visualization.plot_window.AxisAlignedSlicePlot`
    and :class:`yt.visualization.plot_window.OffAxisSlicePlot` objects.  This
    essentially allows for a single entry point to both types of slice plots,
    the distinction being determined by the specified normal vector to the
    slice.

    The returned plot object can be updated using one of the many helper
    functions defined in PlotWindow.

    Parameters
    ----------

    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object corresponding to the
        simulation output to be plotted.
    normal : int or one of 'x', 'y', 'z', or sequence of floats
        This specifies the normal vector to the slice.  If given as an integer
        or a coordinate string (0=x, 1=y, 2=z), this function will return an
        :class:`AxisAlignedSlicePlot` object.  If given as a sequence of floats,
        this is interpreted as an off-axis vector and an
        :class:`OffAxisSlicePlot` object is returned.  A sequence with exactly
        one nonzero component is treated as the corresponding axis name and
        yields an :class:`AxisAlignedSlicePlot`.
    fields : string
         The name of the field(s) to be plotted.
    axis : int or one of 'x', 'y', 'z'
         An int corresponding to the axis to slice along (0=x, 1=y, 2=z)
         or the axis name itself.  If specified, this will replace normal.
         Deprecated alias for ``normal``.


    The following are nominally keyword arguments passed onto the respective
    slice plot objects generated by this function.

    Keyword Arguments
    -----------------

    center : A sequence of floats, a string, or a tuple.
         The coordinate of the center of the image. If set to 'c', 'center' or
         left blank, the plot is centered on the middle of the domain. If set to
         'max' or 'm', the center will be located at the maximum of the
         ('gas', 'density') field. Centering on the max or min of a specific
         field is supported by providing a tuple such as ("min","temperature") or
         ("max","dark_matter_density"). Units can be specified by passing in *center*
         as a tuple containing a coordinate and string unit name or by passing
         in a YTArray. If a list or unitless array is supplied, code units are
         assumed.
    width : tuple or a float.
         Width can have four different formats to support windows with variable
         x and y widths.  They are:

         ==================================     =======================
         format                                 example
         ==================================     =======================
         (float, string)                        (10,'kpc')
         ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
         float                                  0.2
         (float, float)                         (0.2, 0.3)
         ==================================     =======================

         For example, (10, 'kpc') requests a plot window that is 10 kiloparsecs
         wide in the x and y directions, ((10,'kpc'),(15,'kpc')) requests a
         window that is 10 kiloparsecs wide along the x axis and 15
         kiloparsecs wide along the y axis.  In the other two examples, code
         units are assumed, for example (0.2, 0.3) requests a plot that has an
         x width of 0.2 and a y width of 0.3 in code units.  If units are
         provided the resulting plot axis labels will use the supplied units.
    axes_unit : string
         The name of the unit for the tick labels on the x and y axes.
         Defaults to None, which automatically picks an appropriate unit.
         If axes_unit is '1', 'u', or 'unitary', it will not display the
         units, and only show the axes name.
    origin : string or length 1, 2, or 3 sequence.
         The location of the origin of the plot coordinate system for
         `AxisAlignedSlicePlot` object; for `OffAxisSlicePlot` objects this
         parameter is discarded. This is typically represented by a '-'
         separated string or a tuple of strings. In the first index the
         y-location is given by 'lower', 'upper', or 'center'. The second index
         is the x-location, given as 'left', 'right', or 'center'. Finally,
         whether the origin is applied in 'domain' space, plot 'window' space or
         'native' simulation coordinate system is given. For example, both
         'upper-right-domain' and ['upper', 'right', 'domain'] place the
         origin in the upper right hand corner of domain space. If x or y
         are not given, a value is inferred. For instance, 'left-domain'
         corresponds to the lower-left hand corner of the simulation domain,
         'center-domain' corresponds to the center of the simulation domain,
         or 'center-window' for the center of the plot window. In the event
         that none of these options place the origin in a desired location,
         a sequence of tuples and a string specifying the
         coordinate space can be given. If plain numeric types are input,
         units of `code_length` are assumed. Further examples:

         =============================================== ===============================
         format                                          example
         =============================================== ===============================
         '{space}'                                       'domain'
         '{xloc}-{space}'                                'left-window'
         '{yloc}-{space}'                                'upper-domain'
         '{yloc}-{xloc}-{space}'                         'lower-right-window'
         ('{space}',)                                    ('window',)
         ('{xloc}', '{space}')                           ('right', 'domain')
         ('{yloc}', '{space}')                           ('lower', 'window')
         ('{yloc}', '{xloc}', '{space}')                 ('lower', 'right', 'window')
         ((yloc, '{unit}'), (xloc, '{unit}'), '{space}') ((0, 'm'), (.4, 'm'), 'window')
         (xloc, yloc, '{space}')                         (0.23, 0.5, 'domain')
         =============================================== ===============================
    north_vector : a sequence of floats
        A vector defining the 'up' direction in the `OffAxisSlicePlot`; not
        used in `AxisAlignedSlicePlot`.  This option sets the orientation of the
        slicing plane.  If not set, an arbitrary grid-aligned north-vector is
        chosen.
    fontsize : integer
         The size of the fonts for the axis, colorbar, and tick labels.
    field_parameters : dictionary
         A dictionary of field parameters that can be accessed by derived
         fields.
    data_source : YTSelectionContainer Object
         Object to be used for data selection.  Defaults to a region covering
         the entire simulation.

    Raises
    ------

    TypeError
        If no normal is given (via ``normal`` or the deprecated ``axis``),
        if both ``normal`` and ``axis`` are given, or if no field to plot is
        specified.
    ValueError
        If ``normal`` is a sequence whose components are all zero.

    Examples
    --------

    >>> from yt import load
    >>> ds = load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> slc = SlicePlot(ds, "x", ("gas", "density"), center=[0.2, 0.3, 0.4])

    >>> slc = SlicePlot(
    ...     ds, [0.4, 0.2, -0.1], ("gas", "pressure"), north_vector=[0.2, -0.3, 0.1]
    ... )

    """
    if axis is not None:
        issue_deprecation_warning(
            "SlicePlot's argument 'axis' is a deprecated alias for 'normal', it "
            "will be removed in a future version of yt.",
            since="4.0.0",
            removal="4.1.0",
        )
        if normal is not None:
            raise TypeError(
                "SlicePlot() received incompatible arguments 'axis' and 'normal'"
            )
        normal = axis

    # to keep positional ordering we had to make 'normal' and 'fields' keywords
    if normal is None:
        raise TypeError("Missing argument in SlicePlot(): 'normal'")

    if fields is None:
        raise TypeError("Missing argument in SlicePlot(): 'fields'")

    # use an AxisAlignedSlicePlot where possible, e.g.:
    # maybe someone passed normal=[0,0,0.2] when they should have just used "z"
    if is_sequence(normal) and not isinstance(normal, str):
        nonzero_count = np.count_nonzero(normal)
        if nonzero_count == 0:
            # a null vector defines no plane; fail clearly instead of
            # producing a NaN normal further down the line
            raise ValueError(f"SlicePlot() received a null normal vector: {normal!r}")
        if nonzero_count == 1:
            normal = ("x", "y", "z")[np.nonzero(normal)[0][0]]
        else:
            normal = np.array(normal, dtype="float64")
            # scale to unit length (divide by the norm, not its square);
            # the direction is unchanged
            normal /= np.linalg.norm(normal)

    # by now the normal should be properly set to get either a On/Off Axis plot
    if is_sequence(normal) and not isinstance(normal, str):
        # OffAxisSlicePlot has hardcoded origin; remove it if in kwargs
        if "origin" in kwargs:
            mylog.warning(
                "Ignoring 'origin' keyword as it is ill-defined for "
                "an OffAxisSlicePlot object."
            )
            del kwargs["origin"]

        return OffAxisSlicePlot(ds, normal, fields, *args, **kwargs)
    else:
        # north_vector not used in AxisAlignedSlicePlots; remove it if in kwargs
        if "north_vector" in kwargs:
            mylog.warning(
                "Ignoring 'north_vector' keyword as it is ill-defined for "
                "an AxisAlignedSlicePlot object."
            )
            del kwargs["north_vector"]

        return AxisAlignedSlicePlot(ds, normal, fields, *args, **kwargs)
2485
2486
def plot_2d(
    ds,
    fields,
    center="c",
    width=None,
    axes_unit=None,
    origin="center-window",
    fontsize=18,
    field_parameters=None,
    window_size=8.0,
    aspect=None,
    data_source=None,
):
    r"""Creates a plot of a 2D dataset

    Given a ds object and a field name string, this will return a
    PWViewerMPL object containing the plot.

    The plot can be updated using one of the many helper functions
    defined in PlotWindow.

    Parameters
    ----------
    ds : `Dataset`
         This is the dataset object corresponding to the
         simulation output to be plotted.
    fields : string
         The name of the field(s) to be plotted.
    center : A sequence of floats, a string, or a tuple.
         The coordinate of the center of the image. If set to 'c', 'center' or
         left blank, the plot is centered on the middle of the domain. If set to
         'max' or 'm', the center will be located at the maximum of the
         ('gas', 'density') field. Centering on the max or min of a specific
         field is supported by providing a tuple such as ("min","temperature") or
         ("max","dark_matter_density"). Units can be specified by passing in *center*
         as a tuple containing a coordinate and string unit name or by passing
         in a YTArray. If a list or unitless array is supplied, code units are
         assumed. For plot_2d, this keyword accepts a coordinate in two
         dimensions, e.g. (0.5, 0.5) or ((0.5, 0.5), 'kpc'); the third
         coordinate is filled in with the domain center.
    width : tuple or a float.
         Width can have four different formats to support windows with variable
         x and y widths.  They are:

         ==================================     =======================
         format                                 example
         ==================================     =======================
         (float, string)                        (10,'kpc')
         ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
         float                                  0.2
         (float, float)                         (0.2, 0.3)
         ==================================     =======================

         For example, (10, 'kpc') requests a plot window that is 10 kiloparsecs
         wide in the x and y directions, ((10,'kpc'),(15,'kpc')) requests a
         window that is 10 kiloparsecs wide along the x axis and 15
         kiloparsecs wide along the y axis.  In the other two examples, code
         units are assumed, for example (0.2, 0.3) requests a plot that has an
         x width of 0.2 and a y width of 0.3 in code units.  If units are
         provided the resulting plot axis labels will use the supplied units.
    origin : string or length 1, 2, or 3 sequence.
         The location of the origin of the plot coordinate system. This
         is typically represented by a '-' separated string or a tuple of
         strings. In the first index the y-location is given by 'lower',
         'upper', or 'center'. The second index is the x-location, given as
         'left', 'right', or 'center'. Finally, whether the origin is
         applied in 'domain' space, plot 'window' space or 'native'
         simulation coordinate system is given. For example, both
         'upper-right-domain' and ['upper', 'right', 'domain'] place the
         origin in the upper right hand corner of domain space. If x or y
         are not given, a value is inferred. For instance, 'left-domain'
         corresponds to the lower-left hand corner of the simulation domain,
         'center-domain' corresponds to the center of the simulation domain,
         or 'center-window' for the center of the plot window. In the event
         that none of these options place the origin in a desired location,
         a sequence of tuples and a string specifying the
         coordinate space can be given. If plain numeric types are input,
         units of `code_length` are assumed. Further examples:

         =============================================== ===============================
         format                                          example
         =============================================== ===============================
         '{space}'                                       'domain'
         '{xloc}-{space}'                                'left-window'
         '{yloc}-{space}'                                'upper-domain'
         '{yloc}-{xloc}-{space}'                         'lower-right-window'
         ('{space}',)                                    ('window',)
         ('{xloc}', '{space}')                           ('right', 'domain')
         ('{yloc}', '{space}')                           ('lower', 'window')
         ('{yloc}', '{xloc}', '{space}')                 ('lower', 'right', 'window')
         ((yloc, '{unit}'), (xloc, '{unit}'), '{space}') ((0, 'm'), (.4, 'm'), 'window')
         (xloc, yloc, '{space}')                         (0.23, 0.5, 'domain')
         =============================================== ===============================
    axes_unit : string
         The name of the unit for the tick labels on the x and y axes.
         Defaults to None, which automatically picks an appropriate unit.
         If axes_unit is '1', 'u', or 'unitary', it will not display the
         units, and only show the axes name.
    fontsize : integer
         The size of the fonts for the axis, colorbar, and tick labels.
    field_parameters : dictionary
         A dictionary of field parameters that can be accessed by derived
         fields.
    window_size : float
         Forwarded unchanged to AxisAlignedSlicePlot (default 8.0).
    aspect : float or None
         Forwarded unchanged to AxisAlignedSlicePlot (default None).
    data_source: YTSelectionContainer object
         Object to be used for data selection. Defaults to ds.all_data(), a
         region covering the full domain

    Raises
    ------
    RuntimeError
        If ``ds`` is not two-dimensional.
    NotImplementedError
        If the dataset geometry is not one of 'cartesian', 'polar',
        'spectral_cube', 'cylindrical' or 'spherical'.
    """
    if ds.dimensionality != 2:
        raise RuntimeError("plot_2d only plots 2D datasets!")
    if ds.geometry in ["cartesian", "polar", "spectral_cube"]:
        axis = "z"
    elif ds.geometry == "cylindrical":
        axis = "theta"
    elif ds.geometry == "spherical":
        axis = "phi"
    else:
        raise NotImplementedError(
            f"plot_2d does not yet support datasets with {ds.geometry} geometries"
        )
    # Part of the convenience of plot_2d is to eliminate the use of the
    # superfluous coordinate, so we do that also with the center argument.
    # Length-2 centers whose first entry is a string, e.g. ("max", "density"),
    # are field-extremum specifications and are forwarded untouched.
    if (
        not isinstance(center, str)
        and obj_length(center) == 2
        and not isinstance(center[0], str)
    ):
        if isinstance(center[1], str) and obj_length(center[0]) == 2:
            # ((x, y), "unit")
            center = ds.arr(center[0], center[1])
        elif not isinstance(center, YTArray):
            # plain (x, y) pair: code units are assumed
            center = ds.arr(center, "code_length")
        # .to() returns a copy, so a YTArray supplied by the caller is not
        # converted in place
        center = center.to("code_length")
        center = ds.arr([center[0], center[1], ds.domain_center[2]])
    return AxisAlignedSlicePlot(
        ds,
        axis,
        fields,
        center=center,
        width=width,
        axes_unit=axes_unit,
        origin=origin,
        fontsize=fontsize,
        field_parameters=field_parameters,
        window_size=window_size,
        aspect=aspect,
        data_source=data_source,
    )
2630