1import re
2import warnings
3from functools import wraps
4from numbers import Number
5
6import matplotlib
7import numpy as np
8
9from yt.data_objects.data_containers import YTDataContainer
10from yt.data_objects.level_sets.clump_handling import Clump
11from yt.data_objects.selection_objects.cut_region import YTCutRegion
12from yt.data_objects.static_output import Dataset
13from yt.frontends.ytdata.data_structures import YTClumpContainer
14from yt.funcs import is_sequence, mylog, validate_width_tuple
15from yt.geometry.geometry_handler import is_curvilinear
16from yt.geometry.unstructured_mesh_handler import UnstructuredIndex
17from yt.units import dimensions
18from yt.units.yt_array import YTArray, YTQuantity, uhstack
19from yt.utilities.exceptions import YTDataTypeUnsupported
20from yt.utilities.lib.geometry_utils import triangle_plane_intersect
21from yt.utilities.lib.line_integral_convolution import line_integral_convolution_2d
22from yt.utilities.lib.mesh_triangulation import triangulate_indices
23from yt.utilities.lib.pixelization_routines import (
24    pixelize_cartesian,
25    pixelize_off_axis_cartesian,
26)
27from yt.utilities.math_utils import periodic_ray
28from yt.utilities.on_demand_imports import NotAModule
29from yt.visualization.image_writer import apply_colormap
30
# Maps each PlotCallback subclass name to the class itself; filled in by
# PlotCallback.__init_subclass__.
callback_registry = dict()
32
33
34def _verify_geometry(func):
35    @wraps(func)
36    def _check_geometry(self, plot):
37        geom = plot.data.ds.coordinates.name
38        supp = self._supported_geometries
39        cs = getattr(self, "coord_system", None)
40        if supp is None or geom in supp:
41            return func(self, plot)
42        if cs in ("axis", "figure") and "force" not in supp:
43            return func(self, plot)
44        raise YTDataTypeUnsupported(geom, supp)
45
46    return _check_geometry
47
48
class PlotCallback:
    """
    Base class for plot callbacks (annotations drawn on top of a plot).

    Every subclass is registered by class name in ``callback_registry`` and
    has its ``__call__`` wrapped by ``_verify_geometry`` so that unsupported
    dataset geometries are rejected before anything is drawn.
    """

    # _supported_geometries is set by subclasses of PlotCallback to a tuple of
    # strings corresponding to the names of the geometries that a callback
    # supports.  By default it is None, which means it supports everything.
    # Note that if there's a coord_system parameter that is set to "axis" or
    # "figure" this is disregarded.  If "force" is included in the tuple, it
    # will *not* check whether or not the coord_system is in axis or figure,
    # and will only look at the geometries.
    _supported_geometries = None

    def __init_subclass__(cls, *args, **kwargs):
        super().__init_subclass__(*args, **kwargs)
        # Register the subclass and guard its __call__ with the geometry check.
        callback_registry[cls.__name__] = cls
        cls.__call__ = _verify_geometry(cls.__call__)

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, plot):
        raise NotImplementedError

    def _project_coords(self, plot, coord):
        """
        Convert coordinates from simulation data coordinates to projected
        data coordinates.  Simulation data coordinates are three dimensional,
        and can either be specified as a YTArray or as a list or array in
        code_length units.  Projected data units are 2D versions of the
        simulation data units relative to the axes of the final plot.

        The caller's array is never modified; a converted copy is used.

        Returns
        -------
        tuple of two arrays: the (x, y) components in the plot plane.

        Raises
        ------
        ValueError
            If *coord* does not have 3 components, or if ``plot.data.axis``
            is neither an on-axis index (0-2) nor 4 (off-axis).
        """
        if len(coord) != 3:
            raise ValueError("'data' coordinates must be 3 dimensions")

        if isinstance(coord, YTArray):
            # .to() returns a copy; convert_to_units() would rewrite the
            # caller's array in place.
            coord = coord.to("code_length")
        else:
            coord = plot.data.ds.arr(coord, "code_length")

        ax = plot.data.axis
        if 0 <= ax <= 2:
            # On-axis projection or slice: just grab the appropriate 2 coords
            # for the on-axis view.
            coordinates = plot.data.ds.coordinates
            xi, yi = coordinates.x_axis[ax], coordinates.y_axis[ax]
            return (coord[xi], coord[yi])

        if ax == 4:
            # Off-axis projection or slice (i.e. cutting plane): compute where
            # the data coords fall in the projected plane.  The transpose
            # gives [[x1,x2,...],[y1,y2,...],[z1,z2,...]] in the same order as
            # plot.data.center for array arithmetic.
            coord_vectors = coord.transpose() - plot.data.center
            x = np.dot(coord_vectors, plot.data.orienter.unit_vectors[1])
            y = np.dot(coord_vectors, plot.data.orienter.unit_vectors[0])
            # Transpose into image coords. Due to VR being not a
            # right-handed coord system
            return (y, x)

        raise ValueError("Object being plot must have a `data.axis` defined")

    def _convert_to_plot(self, plot, coord, offset=True):
        """
        Convert coordinates from projected data coordinates to PlotWindow
        plot coordinates.  Projected data coordinates are two dimensional
        and refer to the location relative to the specific axes being plotted,
        although still in simulation units.  PlotWindow plot coordinates
        are locations as found in the final plot, usually with the origin
        in the center of the image and the extent of the image defined by
        the final plot axis markers.

        *coord* is either a single ``(x, y)`` pair or a 2 x N array-like of
        N points.  Returns a length-2 array for a single pair, otherwise a
        2 x N array.  *offset* is accepted but not used.
        """
        # coord should be a 2 x ncoord array-like datatype.
        try:
            ncoord = np.array(coord).shape[1]
        except IndexError:
            ncoord = 1

        # Convert the data and plot limits to tiled numpy arrays so that
        # convert_to_plot is automatically vectorized.
        plot_xlim = plot._axes.get_xlim()
        plot_ylim = plot._axes.get_ylim()

        x0 = np.array(np.tile(plot.xlim[0].to("code_length"), ncoord))
        x1 = np.array(np.tile(plot.xlim[1].to("code_length"), ncoord))
        xx0 = np.tile(plot_xlim[0], ncoord)
        xx1 = np.tile(plot_xlim[1], ncoord)

        y0 = np.array(np.tile(plot.ylim[0].to("code_length"), ncoord))
        y1 = np.array(np.tile(plot.ylim[1].to("code_length"), ncoord))
        yy0 = np.tile(plot_ylim[0], ncoord)
        yy1 = np.tile(plot_ylim[1], ncoord)

        # Unit-carrying input is converted to code_length; plain sequences
        # (including tuples of arrays) are taken as already in code_length.
        try:
            ccoord = np.array(coord.to("code_length"))
        except AttributeError:
            ccoord = np.array(coord)

        # We need a special case for when we are only given one coordinate.
        if ccoord.shape == (2,):
            return np.array(
                [
                    ((ccoord[0] - x0) / (x1 - x0) * (xx1 - xx0) + xx0)[0],
                    ((ccoord[1] - y0) / (y1 - y0) * (yy1 - yy0) + yy0)[0],
                ]
            )
        return np.array(
            [
                (ccoord[0][:] - x0) / (x1 - x0) * (xx1 - xx0) + xx0,
                (ccoord[1][:] - y0) / (y1 - y0) * (yy1 - yy0) + yy0,
            ]
        )

    def _sanitize_coord_system(self, plot, coord, coord_system):
        """
        Given a set of one or more x,y (and z) coordinates and a coordinate
        system, convert the coordinates (and transformation) ready for final
        plotting.

        Parameters
        ----------

        plot: a PlotMPL subclass
           The plot that we are converting coordinates for

        coord: array-like
           Coordinates in some coordinate system: [x,y,z].
           Alternatively, can specify multiple coordinates as:
           [[x1,x2,...,xn], [y1, y2,...,yn], [z1,z2,...,zn]]

        coord_system: string

            Possible values include:

            * ``'data'``
                3D data coordinates relative to original dataset

            * ``'plot'``
                2D coordinates as defined by the final axis locations

            * ``'axis'``
                2D coordinates within the axis object from (0,0) in lower left
                to (1,1) in upper right.  Same as matplotlib axis coords.

            * ``'figure'``
                2D coordinates within figure object from (0,0) in lower left
                to (1,1) in upper right.  Same as matplotlib figure coords.

        Returns
        -------
        The (possibly converted) coordinates.  As a side effect,
        ``self.transform`` is set to the matching matplotlib transform.

        Raises
        ------
        ValueError
            For 'data' coordinates with fewer than 3 components, 'axis'
            coordinates with more than 2 components, or an unknown
            *coord_system*.
        """
        # Assure coords are either a YTArray or numpy array
        coord = np.asanyarray(coord, dtype="float64")
        # if in data coords, project them to plot coords
        if coord_system == "data":
            if len(coord) < 3:
                raise ValueError(
                    "Coordinates in 'data' coordinate system need to be in 3D"
                )
            coord = self._project_coords(plot, coord)
            coord = self._convert_to_plot(plot, coord)
        # data and plot coords are both drawn with the data transform
        if coord_system == "data" or coord_system == "plot":
            self.transform = plot._axes.transData
            return coord
        # if in axis coords, define the transform correctly
        if coord_system == "axis":
            self.transform = plot._axes.transAxes
            if len(coord) > 2:
                raise ValueError(
                    "Coordinates in 'axis' coordinate system need to be in 2D"
                )
            return coord
        # if in figure coords, define the transform correctly
        if coord_system == "figure":
            self.transform = plot._figure.transFigure
            return coord
        raise ValueError(
            "Argument coord_system must have a value of "
            "'data', 'plot', 'axis', or 'figure'."
        )

    def _physical_bounds(self, plot):
        """Return the plot limits (x0, x1, y0, y1) converted to code_length."""
        xlims = tuple(v.in_units("code_length") for v in plot.xlim)
        ylims = tuple(v.in_units("code_length") for v in plot.ylim)
        return xlims + ylims

    def _plot_bounds(self, plot):
        """Return the matplotlib axes limits as (xx0, xx1, yy0, yy1)."""
        return plot._axes.get_xlim() + plot._axes.get_ylim()

    def _pixel_scale(self, plot):
        """Return (dx, dy): plot-axis units per code_length along each axis."""
        x0, x1, y0, y1 = self._physical_bounds(plot)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        dx = (xx1 - xx0) / (x1 - x0)
        dy = (yy1 - yy0) / (y1 - y0)
        return dx, dy

    def _set_font_properties(self, plot, labels, **kwargs):
        """
        This sets all of the text instances created by a callback to have
        the same font size and properties as all of the other fonts in the
        figure.  If kwargs are set, they override the defaults.

        Recognized font keywords (each forwarded to the matching
        ``FontProperties.set_<key>`` method, in this order): family, file,
        fontconfig_pattern, name, size, slant, stretch, style, variant,
        weight.  The plot's font color is applied only when ``color`` is not
        among *kwargs*.
        """
        # There is no trivial way to update a FontProperties object with new
        # attributes aside from setting them individually, so copy the plot
        # defaults and let the relevant kwargs override them.
        local_font_properties = plot.font_properties.copy()

        # Turn off the default TT font file, otherwise none of this works.
        local_font_properties.set_file(None)
        local_font_properties.set_family("stixgeneral")

        font_keys = (
            "family",
            "file",
            "fontconfig_pattern",
            "name",
            "size",
            "slant",
            "stretch",
            "style",
            "variant",
            "weight",
        )
        for key in font_keys:
            if key in kwargs:
                getattr(local_font_properties, f"set_{key}")(kwargs[key])

        # For each label, set the font properties and color to the figure
        # defaults if not already set in the callback itself
        for label in labels:
            if plot.font_color is not None and "color" not in kwargs:
                label.set_color(plot.font_color)
            label.set_fontproperties(local_font_properties)
289
290
class VelocityCallback(PlotCallback):
    """
    Adds a 'quiver' plot of velocity to the plot, skipping all but
    every *factor* datapoint. *scale* is the data units per arrow
    length unit using *scale_units* and *plot_args* allows you to
    pass in matplotlib arguments (see matplotlib.axes.Axes.quiver
    for more info). if *normalize* is True, the velocity fields
    will be scaled by their local (in-plane) length, allowing
    morphological features to be more clearly seen for fields
    with substantial variation in field strength.
    """

    _type_name = "velocity"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self, factor=16, scale=None, scale_units=None, normalize=False, plot_args=None
    ):
        PlotCallback.__init__(self)
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args

    def __call__(self, plot):
        """
        Pick the velocity components for *plot* and delegate the drawing to
        a quiver callback.

        Cutting planes use the ``cutting_plane_velocity_x/y`` fields; on-axis
        plots use the in-plane velocity components (Cartesian x/y for the z
        axis of polar/cylindrical data) with the data object's
        ``bulk_velocity`` field parameter, if set, passed on for subtraction.

        Raises
        ------
        NotImplementedError
            For a cutting plane through curvilinear geometry.
        """
        ftype = plot.data._current_fluid_type
        # Instantiation of these is cheap
        if plot._type_name == "CuttingPlane":
            if is_curvilinear(plot.data.ds.geometry):
                raise NotImplementedError(
                    "Velocity annotation for cutting plane is not supported "
                    f"for {plot.data.ds.geometry} geometry"
                )
            qcb = CuttingQuiverCallback(
                (ftype, "cutting_plane_velocity_x"),
                (ftype, "cutting_plane_velocity_y"),
                self.factor,
                scale=self.scale,
                normalize=self.normalize,
                scale_units=self.scale_units,
                plot_args=self.plot_args,
            )
        else:
            xax = plot.data.ds.coordinates.x_axis[plot.data.axis]
            yax = plot.data.ds.coordinates.y_axis[plot.data.axis]
            axis_names = plot.data.ds.coordinates.axis_name

            bv = plot.data.get_field_parameter("bulk_velocity")
            if bv is not None:
                bv_x = bv[xax]
                bv_y = bv[yax]
            else:
                bv_x = bv_y = 0

            if (
                plot.data.ds.geometry in ["polar", "cylindrical"]
                and axis_names[plot.data.axis] == "z"
            ):
                # polar_z and cyl_z is aligned with carteian_z
                # should convert r-theta plane to x-y plane
                xv = (ftype, "velocity_cartesian_x")
                yv = (ftype, "velocity_cartesian_y")
            else:
                # for other cases (even for cylindrical geometry),
                # orthogonal planes are generically Cartesian
                xv = (ftype, f"velocity_{axis_names[xax]}")
                yv = (ftype, f"velocity_{axis_names[yax]}")

            # determine the full fields including field type
            xv = plot.data._determine_fields(xv)[0]
            yv = plot.data._determine_fields(yv)[0]

            qcb = QuiverCallback(
                xv,
                yv,
                self.factor,
                scale=self.scale,
                scale_units=self.scale_units,
                normalize=self.normalize,
                bv_x=bv_x,
                bv_y=bv_y,
                plot_args=self.plot_args,
            )
        return qcb(plot)
380
381
class MagFieldCallback(PlotCallback):
    """
    Adds a 'quiver' plot of magnetic field to the plot, skipping all but
    every *factor* datapoint. *scale* is the data units per arrow
    length unit using *scale_units* and *plot_args* allows you to pass
    in matplotlib arguments (see matplotlib.axes.Axes.quiver for more info).
    if *normalize* is True, the magnetic fields will be scaled by their
    local (in-plane) length, allowing morphological features to be more
    clearly seen for fields with substantial variation in field strength.
    """

    _type_name = "magnetic_field"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self, factor=16, scale=None, scale_units=None, normalize=False, plot_args=None
    ):
        PlotCallback.__init__(self)
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args

    def __call__(self, plot):
        """
        Pick the magnetic field components for *plot* and delegate the
        drawing to a quiver callback.

        Cutting planes use the ``cutting_plane_magnetic_field_x/y`` fields;
        on-axis plots use the in-plane components (Cartesian x/y for the z
        axis of polar/cylindrical data).

        Raises
        ------
        NotImplementedError
            For a cutting plane through curvilinear geometry.
        """
        ftype = plot.data._current_fluid_type
        # Instantiation of these is cheap
        if plot._type_name == "CuttingPlane":
            if is_curvilinear(plot.data.ds.geometry):
                raise NotImplementedError(
                    "Magnetic field annotation for cutting plane is not "
                    f"supported for {plot.data.ds.geometry} geometry"
                )
            qcb = CuttingQuiverCallback(
                (ftype, "cutting_plane_magnetic_field_x"),
                (ftype, "cutting_plane_magnetic_field_y"),
                self.factor,
                scale=self.scale,
                scale_units=self.scale_units,
                normalize=self.normalize,
                plot_args=self.plot_args,
            )
        else:
            xax = plot.data.ds.coordinates.x_axis[plot.data.axis]
            yax = plot.data.ds.coordinates.y_axis[plot.data.axis]
            axis_names = plot.data.ds.coordinates.axis_name

            if (
                plot.data.ds.geometry in ["polar", "cylindrical"]
                and axis_names[plot.data.axis] == "z"
            ):
                # polar_z and cyl_z is aligned with carteian_z
                # should convert r-theta plane to x-y plane
                xv = (ftype, "magnetic_field_cartesian_x")
                yv = (ftype, "magnetic_field_cartesian_y")
            else:
                # for other cases (even for cylindrical geometry),
                # orthogonal planes are generically Cartesian
                xv = (ftype, f"magnetic_field_{axis_names[xax]}")
                yv = (ftype, f"magnetic_field_{axis_names[yax]}")

            qcb = QuiverCallback(
                xv,
                yv,
                self.factor,
                scale=self.scale,
                scale_units=self.scale_units,
                normalize=self.normalize,
                plot_args=self.plot_args,
            )
        return qcb(plot)
456
457
class QuiverCallback(PlotCallback):
    """
    Adds a 'quiver' plot to any plot, using the *field_x* and *field_y*
    from the associated data, skipping every *factor* datapoints.
    *scale* is the data units per arrow length unit using *scale_units*
    and *plot_args* allows you to pass in matplotlib arguments (see
    matplotlib.axes.Axes.quiver for more info). if *normalize* is True,
    the fields will be scaled by their local (in-plane) length, allowing
    morphological features to be more clearly seen for fields with
    substantial variation in field strength.  If either of *bv_x* or *bv_y*
    is non-zero, that vector is subtracted from the field before plotting.
    """

    _type_name = "quiver"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self,
        field_x,
        field_y,
        factor=16,
        scale=None,
        scale_units=None,
        normalize=False,
        bv_x=0,
        bv_y=0,
        plot_args=None,
    ):
        PlotCallback.__init__(self)
        self.field_x = field_x
        self.field_y = field_y
        self.bv_x = bv_x
        self.bv_y = bv_y
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args

    def __call__(self, plot):
        """Pixelize both components on a coarse grid and draw the arrows."""
        x0, x1, y0, y1 = self._physical_bounds(plot)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        bounds = [x0, x1, y0, y1]
        periodic = int(any(plot.data.ds.periodicity))

        def transform(field_name, vector_value):
            """Register ``field_name - vector_value`` as a field; return its key."""
            field_units = plot.data[field_name].units
            new_field = ("gas", f"transformed_{field_name}")

            def _transformed_field(field, data):
                return data[field_name] - data.ds.arr(vector_value, field_units)

            plot.data.ds.add_field(
                new_field,
                sampling_type="cell",
                function=_transformed_field,
                units=field_units,
                display_field=False,
            )
            return new_field

        # A non-zero bulk vector in *either* component requires the relative
        # field; checking only bv_x would ignore a purely-y bulk vector.
        if self.bv_x != 0.0 or self.bv_y != 0.0:
            # We create a relative vector field, looked up by the exact key
            # it was registered under.
            field_x = transform(self.field_x, self.bv_x)
            field_y = transform(self.field_y, self.bv_y)
        else:
            field_x, field_y = self.field_x, self.field_y

        # We are feeding this size into the pixelizer, where it will properly
        # set it in reverse order
        nx = plot.image._A.shape[1] // self.factor
        ny = plot.image._A.shape[0] // self.factor
        pixX = plot.data.ds.coordinates.pixelize(
            plot.data.axis,
            plot.data,
            field_x,
            bounds,
            (nx, ny),
            False,  # antialias
            periodic,
        )
        pixY = plot.data.ds.coordinates.pixelize(
            plot.data.axis,
            plot.data,
            field_y,
            bounds,
            (nx, ny),
            False,  # antialias
            periodic,
        )
        X, Y = np.meshgrid(
            np.linspace(xx0, xx1, nx, endpoint=True),
            np.linspace(yy0, yy1, ny, endpoint=True),
        )
        if self.normalize:
            nn = np.sqrt(pixX ** 2 + pixY ** 2)
            pixX /= nn
            pixY /= nn
        plot._axes.quiver(
            X,
            Y,
            pixX,
            pixY,
            scale=self.scale,
            scale_units=self.scale_units,
            **self.plot_args,
        )
        # Drawing the quiver may autoscale the axes; restore the plot limits.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
568
569
class ContourCallback(PlotCallback):
    """
    Add contours in *field* to the plot.  *ncont* governs the number of
    contours generated, *factor* governs the number of points used in the
    interpolation, *take_log* governs how it is contoured and *clim* gives
    the (lower, upper) limits for contouring.  An alternate data source can be
    specified with *data_source*, but by default the plot's data source will be
    queried.
    """

    _type_name = "contour"
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(
        self,
        field,
        ncont=5,
        factor=4,
        clim=None,
        plot_args=None,
        label=False,
        take_log=None,
        label_args=None,
        text_args=None,
        data_source=None,
    ):
        PlotCallback.__init__(self)
        def_plot_args = {"colors": "k", "linestyles": "solid"}
        def_text_args = {"colors": "w"}
        self.ncont = ncont
        self.field = field
        self.factor = factor
        self.clim = clim
        self.take_log = take_log
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args
        self.label = label
        if label_args is not None:
            text_args = label_args
            warnings.warn(
                "The label_args keyword is deprecated.  Please use "
                "the text_args keyword instead."
            )
        if text_args is None:
            text_args = def_text_args
        self.text_args = text_args
        self.data_source = data_source

    def __call__(self, plot):
        """
        Interpolate *field* onto a coarse grid and draw its contours.

        The stored ``take_log``, ``clim`` and ``ncont`` settings are only
        read here (derived values live in locals), so redrawing the plot
        produces the same contours every time.

        Raises
        ------
        NotImplementedError
            For plot types other than CuttingPlane, Projection, Slice and
            OffAxisProjection.
        """
        from matplotlib.tri import LinearTriInterpolator, Triangulation

        # These need to be in code_length
        x0, x1, y0, y1 = self._physical_bounds(plot)

        # These are in plot coordinates, which may not be code coordinates.
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

        # See the note about rows/columns in the pixelizer for more information
        # on why we choose the bounds we do
        numPoints_x = plot.image._A.shape[1]
        numPoints_y = plot.image._A.shape[0]

        # Multiply by dx and dy to go from data->plot
        dx = (xx1 - xx0) / (x1 - x0)
        dy = (yy1 - yy0) / (y1 - y0)

        # We want xi, yi in plot coordinates
        xi, yi = np.mgrid[
            xx0 : xx1 : numPoints_x / (self.factor * 1j),
            yy0 : yy1 : numPoints_y / (self.factor * 1j),
        ]
        data = self.data_source if self.data_source is not None else plot.data

        if plot._type_name in ["CuttingPlane", "Projection", "Slice"]:
            if plot._type_name == "CuttingPlane":
                x = data["px"] * dx
                y = data["py"] * dy
                z = data[self.field]
            else:
                # Projection or Slice: makes a copy of the position fields
                # "px" and "py" and adds the appropriate periodic shift to the
                # copied field.
                AllX = np.zeros(data["px"].size, dtype="bool")
                AllY = np.zeros(data["py"].size, dtype="bool")
                XShifted = data["px"].copy()
                YShifted = data["py"].copy()
                dom_x, dom_y = plot._period
                for shift in np.mgrid[-1:1:3j]:
                    xlim = (data["px"] + shift * dom_x >= x0) & (
                        data["px"] + shift * dom_x <= x1
                    )
                    ylim = (data["py"] + shift * dom_y >= y0) & (
                        data["py"] + shift * dom_y <= y1
                    )
                    XShifted[xlim] += shift * dom_x
                    YShifted[ylim] += shift * dom_y
                    AllX |= xlim
                    AllY |= ylim

                # At this point XShifted and YShifted are the shifted arrays of
                # position data in data coordinates
                wI = AllX & AllY

                # This converts XShifted and YShifted into plot coordinates
                x = ((XShifted[wI] - x0) * dx).ndarray_view() + xx0
                y = ((YShifted[wI] - y0) * dy).ndarray_view() + yy0
                z = data[self.field][wI]

            # Both the input and output from the triangulator are in plot
            # coordinates
            triangulation = Triangulation(x, y)
            zi = LinearTriInterpolator(triangulation, z)(xi, yi)
        elif plot._type_name == "OffAxisProjection":
            zi = plot.frb[self.field][:: self.factor, :: self.factor].transpose()
        else:
            raise NotImplementedError(
                f"Contour annotation is not supported for {plot._type_name} plots"
            )

        take_log = self.take_log
        if take_log is None:
            field = data._determine_fields([self.field])[0]
            take_log = plot.ds._get_field_info(*field).take_log

        if take_log:
            zi = np.log10(zi)

        # Work on local copies: mutating self.clim / self.ncont would
        # re-apply log10 and break np.linspace on every subsequent redraw.
        clim = self.clim
        if take_log and clim is not None:
            clim = (np.log10(clim[0]), np.log10(clim[1]))

        levels = self.ncont
        if clim is not None:
            levels = np.linspace(clim[0], clim[1], self.ncont)

        cset = plot._axes.contour(xi, yi, zi, levels, **self.plot_args)
        # Drawing the contours may autoscale the axes; restore the limits.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)

        if self.label:
            plot._axes.clabel(cset, **self.text_args)
705
706
707class GridBoundaryCallback(PlotCallback):
708    """
709    Draws grids on an existing PlotWindow object.  Adds grid boundaries to a
710    plot, optionally with alpha-blending. By default, colors different levels of
711    grids with different colors going from white to black, but you can change to
712    any arbitrary colormap with cmap keyword, to all black grid edges for all
713    levels with cmap=None and edgecolors=None, or to an arbitrary single color
714    for grid edges with edgecolors='YourChosenColor' defined in any of the
715    standard ways (e.g., edgecolors='white', edgecolors='r',
716    edgecolors='#00FFFF', or edgecolor='0.3', where the last is a float in 0-1
717    scale indicating gray).  Note that setting edgecolors overrides cmap if you
718    have both set to non-None values.  Cutoff for display is at min_pix
719    wide. draw_ids puts the grid id a the corner of the grid (but its not so
720    great in projections...).  id_loc determines which corner holds the grid id.
721    One can set min and maximum level of grids to display, and
722    can change the linewidth of the displayed grids.
723    """
724
725    _type_name = "grids"
726    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")
727
728    def __init__(
729        self,
730        alpha=0.7,
731        min_pix=1,
732        min_pix_ids=20,
733        draw_ids=False,
734        id_loc=None,
735        periodic=True,
736        min_level=None,
737        max_level=None,
738        cmap="B-W LINEAR_r",
739        edgecolors=None,
740        linewidth=1.0,
741    ):
742        PlotCallback.__init__(self)
743        self.alpha = alpha
744        self.min_pix = min_pix
745        self.min_pix_ids = min_pix_ids
746        self.draw_ids = draw_ids  # put grid numbers in the corner.
747        if id_loc is None:
748            self.id_loc = "lower left"
749        else:
750            self.id_loc = id_loc.lower()  # Make case-insensitive
751            if not self.draw_ids:
752                mylog.warning(
753                    "Supplied id_loc but draw_ids is False. Not drawing grid ids"
754                )
755        self.periodic = periodic
756        self.min_level = min_level
757        self.max_level = max_level
758        self.linewidth = linewidth
759        self.cmap = cmap
760        self.edgecolors = edgecolors
761
762    def __call__(self, plot):
763        if plot.data.ds.geometry == "cylindrical" and plot.data.ds.dimensionality == 3:
764            raise NotImplementedError(
765                "Grid annotation is only supported for \
766                for 2D cylindrical geometry, not 3D"
767            )
768        from matplotlib.colors import colorConverter
769
770        x0, x1, y0, y1 = self._physical_bounds(plot)
771        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
772        (dx, dy) = self._pixel_scale(plot)
773        (ypix, xpix) = plot.image._A.shape
774        ax = plot.data.axis
775        px_index = plot.data.ds.coordinates.x_axis[ax]
776        py_index = plot.data.ds.coordinates.y_axis[ax]
777        DW = plot.data.ds.domain_width
778        if self.periodic:
779            pxs, pys = np.mgrid[-1:1:3j, -1:1:3j]
780        else:
781            pxs, pys = np.mgrid[0:0:1j, 0:0:1j]
782        GLE, GRE, levels, block_ids = [], [], [], []
783        for block, _mask in plot.data.blocks:
784            GLE.append(block.LeftEdge.in_units("code_length"))
785            GRE.append(block.RightEdge.in_units("code_length"))
786            levels.append(block.Level)
787            block_ids.append(block.id)
788        if len(GLE) == 0:
789            return
790        # Retain both units and registry
791        GLE = plot.ds.arr(GLE, units=GLE[0].units)
792        GRE = plot.ds.arr(GRE, units=GRE[0].units)
793        levels = np.array(levels)
794        min_level = self.min_level or 0
795        max_level = self.max_level or levels.max()
796
797        # sort the four arrays in order of ascending level, this makes images look nicer
798        new_indices = np.argsort(levels)
799        levels = levels[new_indices]
800        GLE = GLE[new_indices]
801        GRE = GRE[new_indices]
802        block_ids = np.array(block_ids)[new_indices]
803
804        for px_off, py_off in zip(pxs.ravel(), pys.ravel()):
805            pxo = px_off * DW[px_index]
806            pyo = py_off * DW[py_index]
807            left_edge_x = np.array((GLE[:, px_index] + pxo - x0) * dx) + xx0
808            left_edge_y = np.array((GLE[:, py_index] + pyo - y0) * dy) + yy0
809            right_edge_x = np.array((GRE[:, px_index] + pxo - x0) * dx) + xx0
810            right_edge_y = np.array((GRE[:, py_index] + pyo - y0) * dy) + yy0
811            xwidth = xpix * (right_edge_x - left_edge_x) / (xx1 - xx0)
812            ywidth = ypix * (right_edge_y - left_edge_y) / (yy1 - yy0)
813            visible = np.logical_and(
814                np.logical_and(xwidth > self.min_pix, ywidth > self.min_pix),
815                np.logical_and(levels >= min_level, levels <= max_level),
816            )
817
818            # Grids can either be set by edgecolors OR a colormap.
819            if self.edgecolors is not None:
820                edgecolors = colorConverter.to_rgba(self.edgecolors, alpha=self.alpha)
821            else:  # use colormap if not explicitly overridden by edgecolors
822                if self.cmap is not None:
823                    color_bounds = [0, plot.data.ds.index.max_level]
824                    edgecolors = (
825                        apply_colormap(
826                            levels[visible] * 1.0,
827                            color_bounds=color_bounds,
828                            cmap_name=self.cmap,
829                        )[0, :, :]
830                        * 1.0
831                        / 255.0
832                    )
833                    edgecolors[:, 3] = self.alpha
834                else:
835                    edgecolors = (0.0, 0.0, 0.0, self.alpha)
836
837            if visible.nonzero()[0].size == 0:
838                continue
839            verts = np.array(
840                [
841                    (left_edge_x, left_edge_x, right_edge_x, right_edge_x),
842                    (left_edge_y, right_edge_y, right_edge_y, left_edge_y),
843                ]
844            )
845            verts = verts.transpose()[visible, :, :]
846            grid_collection = matplotlib.collections.PolyCollection(
847                verts,
848                facecolors="none",
849                edgecolors=edgecolors,
850                linewidth=self.linewidth,
851            )
852            plot._axes.add_collection(grid_collection)
853
854            visible_ids = np.logical_and(
855                np.logical_and(xwidth > self.min_pix_ids, ywidth > self.min_pix_ids),
856                np.logical_and(levels >= min_level, levels <= max_level),
857            )
858
859            if self.draw_ids:
860                plot_ids = np.where(visible_ids)[0]
861                x = np.empty(plot_ids.size)
862                y = np.empty(plot_ids.size)
863                for i, n in enumerate(plot_ids):
864                    if self.id_loc == "lower left":
865                        x[i] = left_edge_x[n] + (2 * (xx1 - xx0) / xpix)
866                        y[i] = left_edge_y[n] + (2 * (yy1 - yy0) / ypix)
867                    elif self.id_loc == "lower right":
868                        x[i] = right_edge_x[n] - (
869                            (10 * len(str(block_ids[i])) - 2) * (xx1 - xx0) / xpix
870                        )
871                        y[i] = left_edge_y[n] + (2 * (yy1 - yy0) / ypix)
872                    elif self.id_loc == "upper left":
873                        x[i] = left_edge_x[n] + (2 * (xx1 - xx0) / xpix)
874                        y[i] = right_edge_y[n] - (12 * (yy1 - yy0) / ypix)
875                    elif self.id_loc == "upper right":
876                        x[i] = right_edge_x[n] - (
877                            (10 * len(str(block_ids[i])) - 2) * (xx1 - xx0) / xpix
878                        )
879                        y[i] = right_edge_y[n] - (12 * (yy1 - yy0) / ypix)
880                    else:
881                        raise RuntimeError(
882                            "Unrecognized id_loc value ('%s'). "
883                            "Allowed values are 'lower left', lower right', "
884                            "'upper left', and 'upper right'." % self.id_loc
885                        )
886                    plot._axes.text(x[i], y[i], "%d" % block_ids[n], clip_on=True)
887
888
class StreamlineCallback(PlotCallback):
    """
    Add streamlines to any plot, using the *field_x* and *field_y*
    from the associated data, skipping every *factor* datapoints like
    'quiver'. *density* is the index of the amount of the streamlines.
    *field_color* is a field to be used to colormap the streamlines.
    If *display_threshold* is supplied (and *field_color* is set), any
    streamline segments where *field_color* is less than or equal to the
    threshold will be removed by having their line width set to 0.
    *plot_args* is a dictionary of extra keyword arguments passed to
    matplotlib's ``streamplot``; this callback never modifies it.
    """

    _type_name = "streamlines"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self,
        field_x,
        field_y,
        factor=16,
        density=1,
        field_color=None,
        display_threshold=None,
        plot_args=None,
    ):
        PlotCallback.__init__(self)
        self.field_x = field_x
        self.field_y = field_y
        self.field_color = field_color
        self.factor = factor
        self.dens = density
        self.display_threshold = display_threshold
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args

    def __call__(self, plot):
        bounds = self._physical_bounds(plot)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

        # We are feeding this size into the pixelizer, where it will properly
        # set it in reverse order: the pixelized arrays have shape (ny, nx).
        nx = plot.image._A.shape[1] // self.factor
        ny = plot.image._A.shape[0] // self.factor
        pixX = plot.data.ds.coordinates.pixelize(
            plot.data.axis, plot.data, self.field_x, bounds, (nx, ny)
        )
        pixY = plot.data.ds.coordinates.pixelize(
            plot.data.axis, plot.data, self.field_y, bounds, (nx, ny)
        )

        # Per-call copy: the masked linewidth computed below must not be
        # written back into the stored (user-supplied) plot_args.
        plot_args = dict(self.plot_args)
        field_colors = None
        if self.field_color:
            field_colors = plot.data.ds.coordinates.pixelize(
                plot.data.axis, plot.data, self.field_color, bounds, (nx, ny)
            )

            # ``is not None`` so that a threshold of exactly 0 is honored.
            if self.display_threshold is not None:
                mask = field_colors > self.display_threshold
                linewidth = plot_args.get(
                    "linewidth", matplotlib.rcParams["lines.linewidth"]
                )
                try:
                    # Not in place, so a user-supplied linewidth array is left intact.
                    plot_args["linewidth"] = linewidth * mask
                except ValueError as e:
                    err_msg = (
                        "Error applying display threshold: linewidth "
                        f"must have shape ({ny}, {nx}) or be scalar"
                    )
                    raise ValueError(err_msg) from e

        X, Y = (
            np.linspace(xx0, xx1, nx, endpoint=True),
            np.linspace(yy0, yy1, ny, endpoint=True),
        )
        streamplot_args = {
            "x": X,
            "y": Y,
            "u": pixX,
            "v": pixY,
            "density": self.dens,
            "color": field_colors,
        }
        streamplot_args.update(plot_args)
        plot._axes.streamplot(**streamplot_args)
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
984
985
class LinePlotCallback(PlotCallback):
    """
    Overplot a line with endpoints at p1 and p2.  p1 and p2
    should be 2D or 3D coordinates consistent with the coordinate
    system denoted in the "coord_system" keyword.

    Parameters
    ----------
    p1, p2 : 2- or 3-element tuples, lists, or arrays
        These are the coordinates of the endpoints of the line.

    data_coords : boolean, optional
        Deprecated.  Setting it to True is equivalent to
        coord_system="data" and emits a warning.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates p1 and p2.
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    plot_args : dictionary, optional
        This dictionary is passed to the MPL plot function for generating
        the line.  By default, it is: {'color':'white', 'linewidth':2}

    Examples
    --------

    >>> # Overplot a diagonal white line from the lower left corner to upper
    >>> # right corner
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_line([0, 0], [1, 1], coord_system="axis")
    >>> s.save()

    >>> # Overplot a red dashed line from data coordinate (0.1, 0.2, 0.3) to
    >>> # (0.5, 0.6, 0.7)
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_line(
    ...     [0.1, 0.2, 0.3],
    ...     [0.5, 0.6, 0.7],
    ...     coord_system="data",
    ...     plot_args={"color": "red", "linestyle": "--"},
    ... )
    >>> s.save()

    """

    _type_name = "line"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(self, p1, p2, data_coords=False, coord_system="data", plot_args=None):
        PlotCallback.__init__(self)
        self.p1 = p1
        self.p2 = p2
        if plot_args is None:
            plot_args = {"color": "white", "linewidth": 2}
        self.plot_args = plot_args
        if data_coords:
            # Legacy flag: forces data coordinates regardless of coord_system.
            coord_system = "data"
            warnings.warn(
                "The data_coords keyword is deprecated.  Please set "
                "the keyword coord_system='data' instead."
            )
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        # Convert both endpoints into the coordinates used for drawing.
        p1 = self._sanitize_coord_system(plot, self.p1, coord_system=self.coord_system)
        p2 = self._sanitize_coord_system(plot, self.p2, coord_system=self.coord_system)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        plot._axes.plot(
            [p1[0], p2[0]], [p1[1], p2[1]], transform=self.transform, **self.plot_args
        )
        # Restore the plot limits so the line does not rescale the axes.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
1070
1071
class ImageLineCallback(LinePlotCallback):
    """
    This callback is deprecated, as it is simply a wrapper around
    the LinePlotCallback (i.e. annotate_line()).  The only difference is
    that it uses coord_system="axis" by default. Please see LinePlotCallback
    for more information.

    """

    _type_name = "image_line"
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(self, p1, p2, data_coords=False, coord_system="axis", plot_args=None):
        super().__init__(p1, p2, data_coords, coord_system, plot_args)
        warnings.warn(
            "The ImageLineCallback (annotate_image_line()) is "
            "deprecated.  Please use the LinePlotCallback "
            "(annotate_line()) instead."
        )

    def __call__(self, plot):
        # Draw exactly as LinePlotCallback does.
        super().__call__(plot)
1094
1095
class CuttingQuiverCallback(PlotCallback):
    """
    Get a quiver plot on top of a cutting plane, using *field_x* and
    *field_y*, skipping every *factor* datapoint in the discretization.
    *scale* is the data units per arrow length unit using *scale_units*
    and *plot_args* allows you to pass in matplotlib arguments (see
    matplotlib.axes.Axes.quiver for more info). If *normalize* is True,
    the fields will be scaled by their local (in-plane) length, allowing
    morphological features to be more clearly seen for fields with
    substantial variation in field strength.  Where the in-plane field
    is exactly zero, the normalized vector is left at zero length.
    """

    _type_name = "cquiver"
    _supported_geometries = ("cartesian", "spectral_cube")

    def __init__(
        self,
        field_x,
        field_y,
        factor=16,
        scale=None,
        scale_units=None,
        normalize=False,
        plot_args=None,
    ):
        PlotCallback.__init__(self)
        self.field_x = field_x
        self.field_y = field_y
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args

    def __call__(self, plot):
        x0, x1, y0, y1 = self._physical_bounds(plot)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        nx = plot.image._A.shape[1] // self.factor
        ny = plot.image._A.shape[0] // self.factor
        data = plot.data
        # Cell indices ordered from largest to smallest dx.
        indices = np.argsort(data["index", "dx"])[::-1].astype(np.int_)

        # Inputs shared by both pixelizations; only the deposited field differs.
        cell_x = data["index", "x"].to("code_length")
        cell_y = data["index", "y"].to("code_length")
        cell_z = data["index", "z"].to("code_length")
        bounds = (x0, x1, y0, y1)

        def _pixelize(field):
            # Deposit one field onto an (ny, nx) buffer covering the plot.
            buff = np.zeros((ny, nx), dtype="f8")
            pixelize_off_axis_cartesian(
                buff,
                cell_x,
                cell_y,
                cell_z,
                data["px"],
                data["py"],
                data["pdx"],
                data["pdy"],
                data["pdz"],
                data.center,
                data._inv_mat,
                indices,
                data[field],
                bounds,
            )
            return buff

        pixX = _pixelize(self.field_x)
        pixY = _pixelize(self.field_y)

        X, Y = np.meshgrid(
            np.linspace(xx0, xx1, nx, endpoint=True),
            np.linspace(yy0, yy1, ny, endpoint=True),
        )

        if self.normalize:
            nn = np.sqrt(pixX ** 2 + pixY ** 2)
            # Avoid 0/0 -> NaN (and RuntimeWarnings) where the in-plane field
            # vanishes; such vectors simply stay zero-length.
            nn[nn == 0] = 1.0
            pixX /= nn
            pixY /= nn

        plot._axes.quiver(
            X,
            Y,
            pixX,
            pixY,
            scale=self.scale,
            scale_units=self.scale_units,
            **self.plot_args,
        )
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
1194
1195
class ClumpContourCallback(PlotCallback):
    """
    Take a list of *clumps* and plot them as a set of contours.

    Each clump must be a ``Clump`` or a ``YTClumpContainer``.  *plot_args*
    is passed to matplotlib's ``contour``; a ``"color"`` key is renamed to
    ``"colors"``, the keyword ``contour`` accepts.  The caller's dictionary
    is not modified.
    """

    _type_name = "clumps"
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(self, clumps, plot_args=None):
        PlotCallback.__init__(self)
        self.clumps = clumps
        # Copy so the "color" -> "colors" rename does not touch the caller's dict.
        plot_args = {} if plot_args is None else dict(plot_args)
        if "color" in plot_args:
            plot_args["colors"] = plot_args.pop("color")
        self.plot_args = plot_args

    def __call__(self, plot):
        bounds = self._physical_bounds(plot)
        extent = self._plot_bounds(plot)

        ax = plot.data.axis
        px_index = plot.data.ds.coordinates.x_axis[ax]
        py_index = plot.data.ds.coordinates.y_axis[ax]

        xf = plot.data.ds.coordinates.axis_name[px_index]
        yf = plot.data.ds.coordinates.axis_name[py_index]
        dxf = f"d{xf}"
        dyf = f"d{yf}"

        ny, nx = plot.image._A.shape
        # Must share the (ny, nx) layout of the per-clump ``temp`` buffers so
        # that np.maximum can combine them for non-square images too.
        buff = np.zeros((ny, nx), dtype="float64")
        for i, clump in enumerate(reversed(self.clumps)):
            mylog.info("Pixelizing contour %s", i)

            if isinstance(clump, Clump):
                ftype = "index"
            elif isinstance(clump, YTClumpContainer):
                ftype = "grid"
            else:
                raise RuntimeError(
                    f"Unknown field type for object of type {type(clump)}."
                )

            xf_copy = clump[ftype, xf].copy().in_units("code_length")
            yf_copy = clump[ftype, yf].copy().in_units("code_length")

            temp = np.zeros((ny, nx), dtype="f8")
            pixelize_cartesian(
                temp,
                xf_copy,
                yf_copy,
                clump[ftype, dxf].in_units("code_length") / 2.0,
                clump[ftype, dyf].in_units("code_length") / 2.0,
                clump[ftype, dxf].d * 0.0 + i + 1,  # every cell carries value i + 1
                bounds,
                0,
            )
            # Where clumps overlap, keep the larger label value.
            buff = np.maximum(temp, buff)
        self.rv = plot._axes.contour(
            buff, np.unique(buff), extent=extent, **self.plot_args
        )
1257
1258
class ArrowCallback(PlotCallback):
    """
    Overplot arrow(s) pointing at position(s) for highlighting specific
    features.  By default, arrow points from lower left to the designated
    position "pos" with arrow length "length".  Alternatively, if
    "starting_pos" is set, arrow will stretch from "starting_pos" to "pos"
    and "length" will be disregarded.

    "coord_system" keyword refers to positions set in "pos" arg and
    "starting_pos" keyword, which by default are in data coordinates.

    "length", "width", "head_length", and "head_width" keywords for the arrow
    are all in axis units, ie relative to the size of the plot axes as 1,
    even if the position of the arrow is set relative to another coordinate
    system.

    Parameters
    ----------
    pos : array-like
        These are the coordinates where the marker(s) will be overplotted
        Either as [x,y,z] or as [[x1,x2,...],[y1,y2,...],[z1,z2,...]]

    code_size : float or (value, unit) tuple, optional
        Deprecated; use "length" instead.  If set, it overrides "length"
        and sets both the x and y extent of the arrow.  A (value, unit)
        tuple is converted to the plot's length unit.
        Default: None

    length : float, optional
        The length, in axis units, of the arrow.
        Default: 0.03

    width : float, optional
        The width, in axis units, of the tail line of the arrow.
        Default: 0.0001

    head_width : float, optional
        The width, in axis units, of the head of the arrow.
        Default: 0.01

    head_length : float, optional
        The length, in axis units, of the head of the arrow.  If set
        to None, use 1.5*head_width
        Default: 0.01

    starting_pos : 2- or 3-element tuple, list, or array, optional
        These are the coordinates from which the arrow starts towards its
        point.  Not compatible with 'length' kwarg.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    plot_args : dictionary, optional
        This dictionary is passed to the MPL arrow function for generating
        the arrow.  By default, it is: {'color':'white'}

    Examples
    --------

    >>> # Overplot an arrow pointing to feature at data coord: (0.2, 0.3, 0.4)
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_arrow([0.2, 0.3, 0.4])
    >>> s.save()

    >>> # Overplot a red arrow with longer length pointing to plot coordinate
    >>> # (0.1, -0.1)
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_arrow(
    ...     [0.1, -0.1], length=0.06, coord_system="plot", plot_args={"color": "red"}
    ... )
    >>> s.save()

    """

    _type_name = "arrow"
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(
        self,
        pos,
        code_size=None,
        length=0.03,
        width=0.0001,
        head_width=0.01,
        head_length=0.01,
        starting_pos=None,
        coord_system="data",
        plot_args=None,
    ):
        self.pos = pos
        self.code_size = code_size
        self.length = length
        self.width = width
        self.head_width = head_width
        self.head_length = head_length
        self.starting_pos = starting_pos
        self.coord_system = coord_system
        self.transform = None
        if plot_args is None:
            plot_args = {"color": "white"}
        self.plot_args = plot_args

    def __call__(self, plot):
        x, y = self._sanitize_coord_system(
            plot, self.pos, coord_system=self.coord_system
        )
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        # Normalize all of the kwarg lengths to the plot size.  The scaled
        # values are kept in locals: rescaling the attributes in place would
        # compound every time this callback is invoked on the same instance.
        plot_diag = ((yy1 - yy0) ** 2 + (xx1 - xx0) ** 2) ** (0.5)
        length = self.length * plot_diag
        width = self.width * plot_diag
        head_width = self.head_width * plot_diag
        head_length = self.head_length
        if head_length is not None:
            head_length = head_length * plot_diag

        if self.code_size is not None:
            warnings.warn(
                "The code_size keyword is deprecated.  Please use "
                "the length keyword in 'axis' units instead. "
                "Setting code_size overrides length value."
            )
            code_size = self.code_size
            if is_sequence(code_size):
                code_size = plot.data.ds.quan(code_size[0], code_size[1])
                code_size = np.float64(code_size.in_units(plot.xlim[0].units))
            dx = dy = code_size * self._pixel_scale(plot)[0]
        elif self.starting_pos is not None:
            start_x, start_y = self._sanitize_coord_system(
                plot, self.starting_pos, coord_system=self.coord_system
            )
            dx = x - start_x
            dy = y - start_y
        else:
            dx = (xx1 - xx0) * 2 ** (0.5) * length
            dy = (yy1 - yy0) * 2 ** (0.5) * length

        # If the arrow is 0 length
        if dx == dy == 0:
            warnings.warn("The arrow has zero length.  Not annotating.")
            return

        def _draw(tip_x, tip_y):
            # The arrow ends at (tip_x, tip_y); its tail is (dx, dy) behind it.
            plot._axes.arrow(
                tip_x - dx,
                tip_y - dy,
                dx,
                dy,
                width=width,
                head_width=head_width,
                head_length=head_length,
                transform=self.transform,
                length_includes_head=True,
                **self.plot_args,
            )

        try:
            _draw(x, y)
        except ValueError:
            # Fall back to one arrow per position when x/y hold several positions.
            for tip_x, tip_y in zip(x, y):
                _draw(tip_x, tip_y)
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
1437
1438
class MarkerAnnotateCallback(PlotCallback):
    """
    Overplot marker(s) at a position(s) for highlighting specific features.

    Parameters
    ----------
    pos : array-like
        These are the coordinates where the marker(s) will be overplotted
        Either as [x,y,z] or as [[x1,x2,...],[y1,y2,...],[z1,z2,...]]

    marker : string, optional
        The shape of the marker to be passed to the MPL scatter function.
        By default, it is 'x', but other acceptable values are: '.', 'o', 'v',
        '^', 's', 'p', '*', etc.  See matplotlib.markers for more information.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    plot_args : dictionary, optional
        This dictionary is passed to the MPL scatter function for generating
        the marker.  By default, it is: {'color':'white', 's':50}

    Examples
    --------

    >>> # Overplot a white X on a feature at data location (0.4, 0.5, 0.6)
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_marker([0.4, 0.5, 0.6])
    >>> s.save()

    >>> # Overplot a big yellow circle at axis location (0.1, 0.2)
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_marker(
    ...     [0.1, 0.2],
    ...     marker="o",
    ...     coord_system="axis",
    ...     plot_args={"color": "yellow", "s": 200},
    ... )
    >>> s.save()

    """

    _type_name = "marker"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(self, pos, marker="x", coord_system="data", plot_args=None):
        self.pos = pos
        self.marker = marker
        if plot_args is None:
            plot_args = {"color": "w", "s": 50}
        self.plot_args = plot_args
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        x, y = self._sanitize_coord_system(
            plot, self.pos, coord_system=self.coord_system
        )
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        plot._axes.scatter(
            x, y, marker=self.marker, transform=self.transform, **self.plot_args
        )
        # Restore the plot limits so the markers do not rescale the axes.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
1519
1520
class SphereCallback(PlotCallback):
    """
    Overplot a circle with designated center and radius with optional text.

    Parameters
    ----------
    center : 2- or 3-element tuple, list, or array
        These are the coordinates where the circle will be overplotted

    radius : YTArray, float, or (1, ('kpc')) style tuple
        The radius of the circle. A (value, unit) tuple is converted to the
        units of the plot's x-axis; a YTQuantity is converted to the units
        of ``center`` (or code_length if ``center`` carries no units).

    circle_args : dict, optional
        This dictionary is passed to the MPL circle object. By default,
        {'color':'white'}. If it has no 'fill' entry, 'fill': False is
        added to a copy of it (the caller's dictionary is not modified).

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of
        center. Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    text : string, optional
        Optional text to include next to the circle.

    text_args : dictionary, optional
        This dictionary is passed to the MPL text function. By default,
        it is: {'color':'white'}

    Examples
    --------

    >>> # Overplot a white circle of radius 100 kpc over the central galaxy
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_sphere([0.5, 0.5, 0.5], radius=(100, "kpc"))
    >>> s.save()

    """

    _type_name = "sphere"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self,
        center,
        radius,
        circle_args=None,
        text=None,
        coord_system="data",
        text_args=None,
    ):
        self.center = center
        # Keep the user-supplied radius untouched; all unit conversion and
        # scaling happens on a local copy in __call__ so repeated renders
        # produce the same circle.
        self.radius = radius
        # Copy so that adding the "fill" default never mutates the caller's
        # dictionary (which may be shared between several callbacks).
        if circle_args is None:
            circle_args = {"color": "white"}
        else:
            circle_args = dict(circle_args)
        circle_args.setdefault("fill", False)
        self.circle_args = circle_args
        self.text = text
        if text_args is None:
            text_args = {"color": "white"}
        self.text_args = text_args
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        """Draw the circle (and optional text label) onto ``plot``."""
        from matplotlib.patches import Circle

        # Work on a local value: the callback may be invoked again on every
        # plot refresh, and rescaling self.radius in place would compound
        # the pixel/axis scaling each time.
        radius = self.radius
        if is_sequence(radius):
            radius = plot.data.ds.quan(radius[0], radius[1])
            radius = np.float64(radius.in_units(plot.xlim[0].units))
        if isinstance(radius, YTQuantity):
            if isinstance(self.center, YTArray):
                units = self.center.units
            else:
                units = "code_length"
            radius = radius.to(units)

        # This assures the radius has the appropriate size in
        # the different coordinate systems, since one cannot simply
        # apply a different transform for a length in the same way
        # you can for a coordinate.
        if self.coord_system in ("data", "plot"):
            radius = radius * self._pixel_scale(plot)[0]
        else:
            radius = radius / (plot.xlim[1] - plot.xlim[0]).v

        x, y = self._sanitize_coord_system(
            plot, self.center, coord_system=self.coord_system
        )

        cir = Circle((x, y), radius, transform=self.transform, **self.circle_args)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

        plot._axes.add_patch(cir)
        if self.text is not None:
            label = plot._axes.text(
                x, y, self.text, transform=self.transform, **self.text_args
            )
            self._set_font_properties(plot, [label], **self.text_args)

        # Restore the limits in case adding artists triggered autoscaling.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
1636
1637
class TextLabelCallback(PlotCallback):
    """
    Overplot text on the plot at a specified position. An optional inset
    box can be drawn around the text via the inset_box_args dictionary.

    Parameters
    ----------
    pos : 2- or 3-element tuple, list, or array
        The coordinates at which the text is placed.

    text : string
        The text to draw.

    data_coords : boolean, optional
        Deprecated. When True, coord_system is forced to "data" and a
        warning is emitted. Use coord_system="data" instead.

    coord_system : string, optional
        The coordinate system in which pos is expressed. Valid values:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    text_args : dictionary, optional
        Keyword arguments forwarded to the MPL text function.  Defaults to
        {'color':'white'}; other font settings follow the figure's defaults.

    inset_box_args : dictionary, optional
        Keyword arguments for the Matplotlib FancyBboxPatch drawn as an
        inset box around the text (passed as the text's ``bbox``).
        Default: None (no box).

    Examples
    --------

    >>> # Overplot white text at data location [0.55, 0.7, 0.4]
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_text([0.55, 0.7, 0.4], "Here is a galaxy")
    >>> s.save()

    >>> # Overplot yellow text at axis location [0.2, 0.8] with
    >>> # a shaded inset box
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_text(
    ...     [0.2, 0.8],
    ...     "Here is a galaxy",
    ...     coord_system="axis",
    ...     text_args={"color": "yellow"},
    ...     inset_box_args={
    ...         "boxstyle": "square,pad=0.3",
    ...         "facecolor": "black",
    ...         "linewidth": 3,
    ...         "edgecolor": "white",
    ...         "alpha": 0.5,
    ...     },
    ... )
    >>> s.save()
    """

    _type_name = "text"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self,
        pos,
        text,
        data_coords=False,
        coord_system="data",
        text_args=None,
        inset_box_args=None,
    ):
        # Honour the deprecated flag by overriding the coordinate system.
        if data_coords:
            coord_system = "data"
            warnings.warn(
                "The data_coords keyword is deprecated.  Please set "
                "the keyword coord_system='data' instead."
            )
        self.pos = pos
        self.text = text
        self.text_args = {"color": "white"} if text_args is None else text_args
        self.inset_box_args = inset_box_args
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        """Place the text (with optional inset box) onto ``plot``."""
        # Copy so adding "bbox" never leaks back into the stored text_args.
        text_kwargs = self.text_args.copy()
        # Resolving the position also establishes self.transform for the
        # chosen coordinate system, so it must precede the text call.
        x, y = self._sanitize_coord_system(
            plot, self.pos, coord_system=self.coord_system
        )
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

        if self.inset_box_args is not None:
            text_kwargs["bbox"] = self.inset_box_args

        label = plot._axes.text(
            x, y, self.text, transform=self.transform, **text_kwargs
        )
        # Keep this label's font consistent with the figure's other labels.
        self._set_font_properties(plot, [label], **text_kwargs)

        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
1749
1750
class PointAnnotateCallback(TextLabelCallback):
    """
    This callback is deprecated, as it is simply a wrapper around
    the TextLabelCallback (ie annotate_text()).  Please see TextLabelCallback
    for more information.

    """

    _type_name = "point"
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(
        self,
        pos,
        text,
        data_coords=False,
        coord_system="data",
        text_args=None,
        inset_box_args=None,
    ):
        super().__init__(
            pos, text, data_coords, coord_system, text_args, inset_box_args
        )
        # Point users at the replacement API (annotate_text), not at the
        # deprecated annotate_point itself.
        warnings.warn(
            "The PointAnnotateCallback (annotate_point()) is "
            "deprecated.  Please use the TextLabelCallback "
            "(annotate_text()) instead."
        )

    def __call__(self, plot):
        """Delegate drawing to TextLabelCallback."""
        super().__call__(plot)
1782
1783
class HaloCatalogCallback(PlotCallback):
    """
    Plots circles at the locations of all the halos
    in a halo catalog with radii corresponding to the
    virial radius of each halo.

    Note, this functionality requires the yt_astro_analysis
    package. See https://yt-astro-analysis.readthedocs.io/
    for more information.

    Parameters
    ----------
    halo_catalog : Dataset, DataContainer,
                   or yt_astro_analysis.halo_analysis.api.HaloCatalog
        The object containing halos to be overplotted. This can
        be a HaloCatalog object, a loaded halo catalog dataset,
        or a data container from a halo catalog dataset.
    circle_args : dict
        Keyword arguments controlling the appearance of the
        circles, supplied to the Matplotlib patch Circle.
        Default: {'edgecolor': 'white', 'facecolor': 'None'}.
    circle_kwargs : dict
        Deprecated alias for circle_args.
    width : tuple
        The width over which to select halos to plot,
        useful when overplotting to a slice plot. Accepts
        a tuple in the form (1.0, 'Mpc'). Halos whose position
        along the line of sight lies outside this slab (centered
        on the plot's data center) are not drawn or labelled.
    annotate_field : str
        A field contained in the
        halo catalog to add text to the plot near the halo.
        Example: annotate_field = 'particle_mass' will
        write the halo mass next to each halo.
    radius_field : str
        A field contained in the halo
        catalog to set the radius of the circle which will
        surround each halo. Default: 'virial_radius'.
    center_field_prefix : str
        Accepts a field prefix which will
        be used to find the fields containing the coordinates
        of the center of each halo. Ex: 'particle_position'
        will result in the fields 'particle_position_x' for x
        'particle_position_y' for y, and 'particle_position_z'
        for z. Default: 'particle_position'.
    text_args : dict
        Keyword arguments controlling the text appearance of the
        annotated field. Default: {'color': 'white'}.
    font_kwargs : dict
        Deprecated alias for text_args.
    factor : float
        A number the virial radius is multiplied by for
        plotting the circles. Ex: factor = 2.0 will plot
        circles with twice the radius of each halo virial radius.

    Raises
    ------
    RuntimeError
        If halo_catalog is not one of the supported types.

    Examples
    --------

    >>> import yt
    >>> dds = yt.load("Enzo_64/DD0043/data0043")
    >>> hds = yt.load("rockstar_halos/halos_0.0.bin")
    >>> p = yt.ProjectionPlot(
    ...     dds, "x", ("gas", "density"), weight_field=("gas", "density")
    ... )
    >>> p.annotate_halos(hds)
    >>> p.save()

    >>> # plot a subset of all halos
    >>> import yt
    >>> dds = yt.load("Enzo_64/DD0043/data0043")
    >>> hds = yt.load("rockstar_halos/halos_0.0.bin")
    >>> # make a region half the width of the box
    >>> dregion = dds.box(
    ...     dds.domain_center - 0.25 * dds.domain_width,
    ...     dds.domain_center + 0.25 * dds.domain_width,
    ... )
    >>> hregion = hds.box(
    ...     hds.domain_center - 0.25 * hds.domain_width,
    ...     hds.domain_center + 0.25 * hds.domain_width,
    ... )
    >>> p = yt.ProjectionPlot(
    ...     dds,
    ...     "x",
    ...     ("gas", "density"),
    ...     weight_field=("gas", "density"),
    ...     data_source=dregion,
    ...     width=0.5,
    ... )
    >>> p.annotate_halos(hregion)
    >>> p.save()

    >>> # plot halos from a HaloCatalog
    >>> import yt
    >>> from yt.extensions.astro_analysis.halo_analysis.api import HaloCatalog
    >>> dds = yt.load("Enzo_64/DD0043/data0043")
    >>> hds = yt.load("rockstar_halos/halos_0.0.bin")
    >>> hc = HaloCatalog(data_ds=dds, halos_ds=hds)
    >>> p = yt.ProjectionPlot(
    ...     dds, "x", ("gas", "density"), weight_field=("gas", "density")
    ... )
    >>> p.annotate_halos(hc)
    >>> p.save()

    """

    _type_name = "halos"
    region = None
    _descriptor = None
    _supported_geometries = ("cartesian", "spectral_cube")

    def __init__(
        self,
        halo_catalog,
        circle_args=None,
        circle_kwargs=None,
        width=None,
        annotate_field=None,
        radius_field="virial_radius",
        center_field_prefix="particle_position",
        text_args=None,
        font_kwargs=None,
        factor=1.0,
    ):

        # The optional dependency is only needed to recognise HaloCatalog
        # instances; NotAModule stands in when it is not installed.
        try:
            from yt_astro_analysis.halo_analysis.api import HaloCatalog
        except ImportError:
            HaloCatalog = NotAModule("yt_astro_analysis")

        PlotCallback.__init__(self)
        def_circle_args = {"edgecolor": "white", "facecolor": "None"}
        def_text_args = {"color": "white"}

        if isinstance(halo_catalog, YTDataContainer):
            self.halo_data = halo_catalog
        elif isinstance(halo_catalog, Dataset):
            self.halo_data = halo_catalog.all_data()
        elif isinstance(halo_catalog, HaloCatalog):
            if halo_catalog.data_source.ds == halo_catalog.halos_ds:
                self.halo_data = halo_catalog.data_source
            else:
                self.halo_data = halo_catalog.halos_ds.all_data()
        else:
            raise RuntimeError(
                "halo_catalog argument must be a HaloCatalog object, "
                + "a dataset, or a data container."
            )

        self.width = width
        self.radius_field = radius_field
        self.center_field_prefix = center_field_prefix
        self.annotate_field = annotate_field
        if circle_kwargs is not None:
            circle_args = circle_kwargs
            warnings.warn(
                "The circle_kwargs keyword is deprecated.  Please "
                "use the circle_args keyword instead."
            )
        if font_kwargs is not None:
            text_args = font_kwargs
            warnings.warn(
                "The font_kwargs keyword is deprecated.  Please use "
                "the text_args keyword instead."
            )
        if circle_args is None:
            circle_args = def_circle_args
        self.circle_args = circle_args
        if text_args is None:
            text_args = def_text_args
        self.text_args = text_args
        self.factor = factor

    def __call__(self, plot):
        """Draw one circle (and optional text label) per halo on ``plot``."""
        from matplotlib.patches import Circle

        data = plot.data
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

        halo_data = self.halo_data
        axis_names = plot.data.ds.coordinates.axis_name
        xax = plot.data.ds.coordinates.x_axis[data.axis]
        yax = plot.data.ds.coordinates.y_axis[data.axis]
        field_x = f"{self.center_field_prefix}_{axis_names[xax]}"
        field_y = f"{self.center_field_prefix}_{axis_names[yax]}"
        field_z = f"{self.center_field_prefix}_{axis_names[data.axis]}"

        # Set up scales for pixel size and original data
        pixel_scale = self._pixel_scale(plot)[0]
        units = plot.xlim[0].units

        # Convert halo positions to code units of the plotted data
        # and then to units of the plotted window
        px = halo_data[("all", field_x)][:].in_units(units)
        py = halo_data[("all", field_y)][:].in_units(units)

        xplotcenter = (plot.xlim[0] + plot.xlim[1]) / 2
        yplotcenter = (plot.ylim[0] + plot.ylim[1]) / 2

        xdomaincenter = plot.ds.domain_center[xax]
        ydomaincenter = plot.ds.domain_center[yax]

        xoffset = xplotcenter - xdomaincenter
        yoffset = yplotcenter - ydomaincenter

        xdw = plot.ds.domain_width[xax].to(units)
        ydw = plot.ds.domain_width[yax].to(units)

        # Fold positions periodically (modulo the domain width) into a
        # window offset by the distance between plot and domain centers.
        modpx = np.mod(px - xoffset, xdw) + xoffset
        modpy = np.mod(py - yoffset, ydw) + yoffset

        px[modpx != px] = modpx[modpx != px]
        py[modpy != py] = modpy[modpy != py]

        px, py = self._convert_to_plot(plot, [px, py])

        # Convert halo radii to a radius in pixels
        radius = halo_data[("all", self.radius_field)][:].in_units(units)
        radius = np.array(radius * pixel_scale * self.factor)

        # Read the annotation values before the slab selection so that the
        # same selection can be applied to them; otherwise labels would be
        # paired with the wrong halos once positions are filtered.
        annotate_dat = None
        if self.annotate_field:
            annotate_dat = halo_data[("all", self.annotate_field)]

        if self.width:
            pz = halo_data[("all", field_z)][:].in_units("code_length")
            c = data.center[data.axis]

            # I should catch an error here if width isn't in this form
            # but I dont really want to reimplement get_sanitized_width...
            width = data.ds.arr(self.width[0], self.width[1]).in_units("code_length")

            indices = np.where((pz > c - 0.5 * width) & (pz < c + 0.5 * width))

            px = px[indices]
            py = py[indices]
            radius = radius[indices]
            if annotate_dat is not None:
                annotate_dat = annotate_dat[indices]

        for x, y, r in zip(px, py, radius):
            plot._axes.add_artist(Circle(xy=(x, y), radius=r, **self.circle_args))

        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)

        if annotate_dat is not None:
            texts = [f"{float(dat):g}" for dat in annotate_dat]
            labels = [
                plot._axes.text(pos_x, pos_y, t, **self.text_args)
                for pos_x, pos_y, t in zip(px, py, texts)
            ]

            # Set the font properties of text from this callback to be
            # consistent with other text labels in this figure
            self._set_font_properties(plot, labels, **self.text_args)
2028
2029
class ParticleCallback(PlotCallback):
    """
    Adds particle positions, based on a thick slab along *axis* with a
    *width* along the line of sight.  *p_size* controls the number of
    pixels per particle, and *col* governs the color.  *ptype* will
    restrict plotted particles to only those that are of a given type.
    *alpha* determines the opacity of the marker symbol used in the scatter.
    An alternate data source can be specified with *data_source*, but by
    default the plot's data source will be queried.

    *width* may be a (value, unit) tuple, a YTQuantity, or a bare number
    (interpreted as code_length).  *stride* keeps every stride-th selected
    particle.  *minimum_mass* (deprecated) drops particles whose
    particle_mass is below the given value.

    Along periodic directions, particles are also drawn at their periodic
    images (shifted by one domain width) wherever the plot extends past
    the corresponding domain edge.
    """

    _type_name = "particles"
    region = None
    _descriptor = None
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(
        self,
        width,
        p_size=1.0,
        col="k",
        marker="o",
        stride=1,
        ptype="all",
        minimum_mass=None,
        alpha=1.0,
        data_source=None,
    ):
        PlotCallback.__init__(self)
        self.width = width
        self.p_size = p_size
        self.color = col
        self.marker = marker
        self.stride = stride
        self.ptype = ptype
        self.minimum_mass = minimum_mass
        self.alpha = alpha
        self.data_source = data_source
        if self.minimum_mass is not None:
            warnings.warn(
                "The minimum_mass keyword is deprecated.  Please use "
                "an appropriate particle filter and the ptype keyword instead."
            )

    def __call__(self, plot):
        """Scatter the selected particles onto ``plot``."""
        data = plot.data
        # Normalise width to a dataset-bound YTQuantity.
        if is_sequence(self.width):
            validate_width_tuple(self.width)
            self.width = plot.data.ds.quan(self.width[0], self.width[1])
        elif isinstance(self.width, YTQuantity):
            self.width = plot.data.ds.quan(self.width.value, self.width.units)
        else:
            self.width = plot.data.ds.quan(self.width, "code_length")
        # we construct a rectangular prism
        x0, x1, y0, y1 = self._physical_bounds(plot)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        if isinstance(self.data_source, YTCutRegion):
            mylog.warning(
                "Parameter 'width' is ignored in annotate_particles if the "
                "data_source is a cut_region. "
                "See https://github.com/yt-project/yt/issues/1933 for further details."
            )
            self.region = self.data_source
        else:
            self.region = self._get_region((x0, x1), (y0, y1), plot.data.axis, data)
        ax = data.axis
        xax = plot.data.ds.coordinates.x_axis[ax]
        yax = plot.data.ds.coordinates.y_axis[ax]
        axis_names = plot.data.ds.coordinates.axis_name
        field_x = f"particle_position_{axis_names[xax]}"
        field_y = f"particle_position_{axis_names[yax]}"
        pt = self.ptype
        self.periodic_x = plot.data.ds.periodicity[xax]
        self.periodic_y = plot.data.ds.periodicity[yax]
        self.LE = plot.data.ds.domain_left_edge[xax], plot.data.ds.domain_left_edge[yax]
        self.RE = (
            plot.data.ds.domain_right_edge[xax],
            plot.data.ds.domain_right_edge[yax],
        )
        period_x = plot.data.ds.domain_width[xax]
        period_y = plot.data.ds.domain_width[yax]
        particle_x, particle_y = self._enforce_periodic(
            self.region[pt, field_x],
            self.region[pt, field_y],
            x0,
            x1,
            period_x,
            y0,
            y1,
            period_y,
        )
        gg = (
            (particle_x >= x0)
            & (particle_x <= x1)
            & (particle_y >= y0)
            & (particle_y <= y1)
        )
        if self.minimum_mass is not None:
            heavy = np.asarray(self.region[pt, "particle_mass"] >= self.minimum_mass)
            # _enforce_periodic returns whole copies of the particle arrays
            # concatenated end to end; repeat the mask the same number of
            # times so it stays aligned with particle_x / particle_y.
            n_copies = len(particle_x) // max(len(heavy), 1)
            gg &= np.tile(heavy, n_copies)
            if gg.sum() == 0:
                return
        px, py = [particle_x[gg][:: self.stride], particle_y[gg][:: self.stride]]
        px, py = self._convert_to_plot(plot, [px, py])
        plot._axes.scatter(
            px,
            py,
            edgecolors="None",
            marker=self.marker,
            s=self.p_size,
            c=self.color,
            alpha=self.alpha,
        )
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)

    def _enforce_periodic(
        self, particle_x, particle_y, x0, x1, period_x, y0, y1, period_y
    ):
        """Append periodic images of the particles where the plot needs them.

        Along each periodic direction, an image shifted by +period is added
        when the plot's upper bound lies beyond the domain's right edge, and
        an image shifted by -period when the plot's lower bound lies before
        the domain's left edge (this covers both plots lying entirely outside
        the domain and plots that only partially overhang an edge).

        The result is the concatenation of whole, equally ordered copies of
        the input arrays, one per (x shift, y shift) combination, with the
        unshifted copy first; every combination appears exactly once.
        """
        # None marks "no shift"; it avoids adding a unitless zero to a
        # unit-carrying array.
        x_shifts = [None]
        if self.periodic_x and x1 > self.RE[0]:
            x_shifts.append(period_x)
        if self.periodic_x and x0 < self.LE[0]:
            x_shifts.append(-period_x)
        y_shifts = [None]
        if self.periodic_y and y1 > self.RE[1]:
            y_shifts.append(period_y)
        if self.periodic_y and y0 < self.LE[1]:
            y_shifts.append(-period_y)

        if len(x_shifts) == 1 and len(y_shifts) == 1:
            return particle_x, particle_y

        xs, ys = [], []
        for sx in x_shifts:
            for sy in y_shifts:
                xs.append(particle_x if sx is None else particle_x + sx)
                ys.append(particle_y if sy is None else particle_y + sy)
        return uhstack(tuple(xs)), uhstack(tuple(ys))

    def _get_region(self, xlim, ylim, axis, data):
        """Return (and cache) a region covering the plot plus the slab width.

        The cached region is reused when it already encloses the requested
        bounds.
        """
        LE, RE = [None] * 3, [None] * 3
        ds = data.ds
        xax = ds.coordinates.x_axis[axis]
        yax = ds.coordinates.y_axis[axis]
        zax = axis
        LE[xax], RE[xax] = xlim
        LE[yax], RE[yax] = ylim
        LE[zax] = data.center[zax] - self.width * 0.5
        LE[zax].convert_to_units("code_length")
        RE[zax] = LE[zax] + self.width
        if (
            self.region is not None
            and np.all(self.region.left_edge <= LE)
            and np.all(self.region.right_edge >= RE)
        ):
            return self.region
        self.region = data.ds.region(data.center, LE, RE, data_source=self.data_source)
        return self.region
2183
2184
class TitleCallback(PlotCallback):
    """
    Set *title* as the axes title of the plot, styled like the figure's
    other text labels.
    """

    _type_name = "title"

    def __init__(self, title):
        super().__init__()
        self.title = title

    def __call__(self, plot):
        """Apply the stored title to ``plot``'s axes."""
        axes = plot._axes
        axes.set_title(self.title)
        # Match the font of the title to the rest of the figure's labels.
        self._set_font_properties(plot, [axes.title])
2202
2203
class MeshLinesCallback(PlotCallback):
    """
    Adds mesh lines to the plot. Only works for unstructured or
    semi-structured mesh data. For structured grid data, see
    GridBoundaryCallback or CellEdgesCallback.

    Parameters
    ----------

    plot_args:   dict, optional
        A dictionary of arguments that will be passed to matplotlib.

    Example
    -------

    >>> import yt
    >>> ds = yt.load("MOOSE_sample_data/out.e-s010")
    >>> sl = yt.SlicePlot(ds, "z", ("connect2", "convected"))
    >>> sl.annotate_mesh_lines(plot_args={"color": "black"})

    """

    _type_name = "mesh_lines"
    _supported_geometries = ("cartesian", "spectral_cube")

    def __init__(self, plot_args=None):
        super().__init__()
        self.plot_args = plot_args

    def promote_2d_to_3d(self, coords, indices, plot):
        """Extrude 2D mesh elements into 3D prisms spanning the domain in z.

        Each 2D vertex is duplicated at the domain's lower and upper z edge;
        the connectivity of each element lists the lower-layer vertices
        followed by the corresponding upper-layer vertices.
        """
        new_coords = np.zeros((2 * coords.shape[0], 3))
        new_connects = np.zeros(
            (indices.shape[0], 2 * indices.shape[1]), dtype=np.int64
        )

        new_coords[0 : coords.shape[0], 0:2] = coords
        new_coords[0 : coords.shape[0], 2] = plot.ds.domain_left_edge[2]
        new_coords[coords.shape[0] :, 0:2] = coords
        new_coords[coords.shape[0] :, 2] = plot.ds.domain_right_edge[2]

        new_connects[:, 0 : indices.shape[1]] = indices
        new_connects[:, indices.shape[1] :] = indices + coords.shape[0]

        return new_coords, new_connects

    def __call__(self, plot):
        """Overlay the mesh element edges on ``plot``.

        When the plotted field type is ``connectN``, only mesh N (1-based)
        is drawn; otherwise all meshes are drawn.

        Raises
        ------
        RuntimeError
            If the dataset index is not an UnstructuredIndex.
        """
        index = plot.ds.index
        if not issubclass(type(index), UnstructuredIndex):
            raise RuntimeError(
                "Mesh line annotations only work for "
                "unstructured or semi-structured mesh data."
            )
        prefix = "connect"
        for i, m in enumerate(index.meshes):
            try:
                ftype, fname = plot.field
                # Parse the full numeric suffix so that e.g. "connect12"
                # selects mesh 12, not mesh 2.
                if ftype.startswith(prefix) and int(ftype[len(prefix) :]) - 1 != i:
                    continue
            except ValueError:
                # plot.field is not a (ftype, fname) pair, or the suffix is
                # not an integer: draw every mesh.
                pass
            coords = m.connectivity_coords
            indices = m.connectivity_indices - m._index_offset

            num_verts = indices.shape[1]
            num_dims = coords.shape[1]

            # 2D triangles and quads are extruded to 3D so they can be
            # triangulated and intersected with the slice plane.
            if num_dims == 2 and num_verts in (3, 4):
                coords, indices = self.promote_2d_to_3d(coords, indices, plot)

            tri_indices = triangulate_indices(indices.astype(np.int_))
            points = coords[tri_indices]

            tfc = TriangleFacetsCallback(points, plot_args=self.plot_args)
            tfc(plot)
2280
2281
class TriangleFacetsCallback(PlotCallback):
    """
    Intended for representing a slice of a triangular faceted
    geometry in a slice plot.

    Uses a set of *triangle_vertices* to find all triangles the plane of a
    SlicePlot intersects with. The lines between the intersection points
    of the triangles are then added to the plot to create an outline
    of the geometry represented by the triangles.

    Parameters
    ----------
    triangle_vertices : array-like
        Vertex coordinates of the triangles.  Values without units (no
        ``in_units`` method) are interpreted as code_length.
    plot_args : dict, optional
        Keyword arguments passed to matplotlib's LineCollection.
    """

    _type_name = "triangle_facets"
    _supported_geometries = ("cartesian", "spectral_cube")

    def __init__(self, triangle_vertices, plot_args=None):
        super().__init__()
        self.plot_args = {} if plot_args is None else plot_args
        self.vertices = triangle_vertices

    def __call__(self, plot):
        """Draw the slice-plane intersection lines of the triangles."""
        ax = plot.data.axis
        xax = plot.data.ds.coordinates.x_axis[ax]
        yax = plot.data.ds.coordinates.y_axis[ax]

        # Attach code_length units to bare arrays; use the dataset (``ds``)
        # like the rest of this callback rather than the legacy ``pf`` alias.
        if not hasattr(self.vertices, "in_units"):
            vertices = plot.data.ds.arr(self.vertices, "code_length")
        else:
            vertices = self.vertices
        l_cy = triangle_plane_intersect(ax, plot.data.coord, vertices)[
            :, :, (xax, yax)
        ]
        # l_cy is shape (nlines, 2, 2)
        # reformat for conversion to plot coordinates
        l_cy = np.rollaxis(l_cy, 0, 3)
        # convert all line starting points
        l_cy[0] = self._convert_to_plot(plot, l_cy[0])
        # convert all line ending points
        l_cy[1] = self._convert_to_plot(plot, l_cy[1])
        # convert back to shape (nlines, 2, 2)
        l_cy = np.rollaxis(l_cy, 2, 0)
        # create line collection and add it to the plot
        lc = matplotlib.collections.LineCollection(l_cy, **self.plot_args)
        plot._axes.add_collection(lc)
2325
2326
2327class TimestampCallback(PlotCallback):
2328    r"""
2329    Annotates the timestamp and/or redshift of the data output at a specified
2330    location in the image (either in a present corner, or by specifying (x,y)
2331    image coordinates with the x_pos, y_pos arguments.  If no time_units are
2332    specified, it will automatically choose appropriate units.  It allows for
2333    custom formatting of the time and redshift information, as well as the
2334    specification of an inset box around the text.
2335
2336    Parameters
2337    ----------
2338
2339    x_pos, y_pos : floats, optional
2340        The image location of the timestamp in the coord system defined by the
2341        coord_system kwarg.  Setting x_pos and y_pos overrides the corner
2342        parameter.
2343
2344    corner : string, optional
2345        Corner sets up one of 4 predeterimined locations for the timestamp
2346        to be displayed in the image: 'upper_left', 'upper_right', 'lower_left',
2347        'lower_right' (also allows None). This value will be overridden by the
2348        optional x_pos and y_pos keywords.
2349
2350    time : boolean, optional
2351        Whether or not to show the ds.current_time of the data output.  Can
2352        be used solo or in conjunction with redshift parameter.
2353
2354    redshift : boolean, optional
2355        Whether or not to show the ds.current_time of the data output.  Can
2356        be used solo or in conjunction with the time parameter.
2357
2358    time_format : string, optional
2359        This specifies the format of the time output assuming "time" is the
2360        number of time and "unit" is units of the time (e.g. 's', 'Myr', etc.)
2361        The time can be specified to arbitrary precision according to printf
2362        formatting codes (defaults to .1f -- a float with 1 digits after
2363        decimal).  Example: "Age = {time:.2f} {units}".
2364
2365    time_unit : string, optional
2366        time_unit must be a valid yt time unit (e.g. 's', 'min', 'hr', 'yr',
2367        'Myr', etc.)
2368
2369    redshift_format : string, optional
        This specifies the format of the redshift output.  The redshift can
        be specified to arbitrary precision using Python format-specification
        codes (defaults to .2f -- a float with 2 digits after the decimal
        point).  Example: "REDSHIFT = {redshift:03.3g}".
2374
2375    draw_inset_box : boolean, optional
        Whether or not an inset box should be included around the text.
        If so, it uses the inset_box_args to set the matplotlib FancyBboxPatch
        object.
2379
2380    coord_system : string, optional
        This string defines the coordinate system of x_pos and y_pos.
        Valid coordinates are:
2383
2384        - "data": 3D dataset coordinates
2385        - "plot": 2D coordinates defined by the actual plot limits
2386        - "axis": MPL axis coordinates: (0,0) is lower left; (1,1) is upper right
2387        - "figure": MPL figure coordinates: (0,0) is lower left, (1,1) is upper right
2388
2389    time_offset : float, (value, unit) tuple, or YTQuantity, optional
2390        Apply an offset to the time shown in the annotation from the
2391        value of the current time. If a scalar value with no units is
2392        passed in, the value of the *time_unit* kwarg is used for the
2393        units. Default: None, meaning no offset.
2394
2395    text_args : dictionary, optional
2396        A dictionary of any arbitrary parameters to be passed to the Matplotlib
2397        text object.  Defaults: ``{'color':'white',
2398        'horizontalalignment':'center', 'verticalalignment':'top'}``.
2399
2400    inset_box_args : dictionary, optional
2401        A dictionary of any arbitrary parameters to be passed to the Matplotlib
2402        FancyBboxPatch object as the inset box around the text.
2403        Defaults: ``{'boxstyle':'square', 'pad':0.3, 'facecolor':'black',
2404        'linewidth':3, 'edgecolor':'white', 'alpha':0.5}``
2405
2406    Example
2407    -------
2408
2409    >>> import yt
2410    >>> ds = yt.load("Enzo_64/DD0020/data0020")
2411    >>> s = yt.SlicePlot(ds, "z", "density")
2412    >>> s.annotate_timestamp()
2413    """
2414
2415    _type_name = "timestamp"
2416    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")
2417
2418    def __init__(
2419        self,
2420        x_pos=None,
2421        y_pos=None,
2422        corner="lower_left",
2423        time=True,
2424        redshift=False,
2425        time_format="t = {time:.1f} {units}",
2426        time_unit=None,
2427        redshift_format="z = {redshift:.2f}",
2428        draw_inset_box=False,
2429        coord_system="axis",
2430        time_offset=None,
2431        text_args=None,
2432        inset_box_args=None,
2433    ):
2434
2435        def_text_args = {
2436            "color": "white",
2437            "horizontalalignment": "center",
2438            "verticalalignment": "top",
2439        }
2440        def_inset_box_args = {
2441            "boxstyle": "square,pad=0.3",
2442            "facecolor": "black",
2443            "linewidth": 3,
2444            "edgecolor": "white",
2445            "alpha": 0.5,
2446        }
2447
2448        # Set position based on corner argument.
2449        self.pos = (x_pos, y_pos)
2450        self.corner = corner
2451        self.time = time
2452        self.redshift = redshift
2453        self.time_format = time_format
2454        self.redshift_format = redshift_format
2455        self.time_unit = time_unit
2456        self.coord_system = coord_system
2457        self.time_offset = time_offset
2458        if text_args is None:
2459            text_args = def_text_args
2460        self.text_args = text_args
2461        if inset_box_args is None:
2462            inset_box_args = def_inset_box_args
2463        self.inset_box_args = inset_box_args
2464
2465        # if inset box is not desired, set inset_box_args to {}
2466        if not draw_inset_box:
2467            self.inset_box_args = None
2468
2469    def __call__(self, plot):
2470        # Setting pos overrides corner argument
2471        if self.pos[0] is None or self.pos[1] is None:
2472            if self.corner == "upper_left":
2473                self.pos = (0.03, 0.96)
2474                self.text_args["horizontalalignment"] = "left"
2475                self.text_args["verticalalignment"] = "top"
2476            elif self.corner == "upper_right":
2477                self.pos = (0.97, 0.96)
2478                self.text_args["horizontalalignment"] = "right"
2479                self.text_args["verticalalignment"] = "top"
2480            elif self.corner == "lower_left":
2481                self.pos = (0.03, 0.03)
2482                self.text_args["horizontalalignment"] = "left"
2483                self.text_args["verticalalignment"] = "bottom"
2484            elif self.corner == "lower_right":
2485                self.pos = (0.97, 0.03)
2486                self.text_args["horizontalalignment"] = "right"
2487                self.text_args["verticalalignment"] = "bottom"
2488            elif self.corner is None:
2489                self.pos = (0.5, 0.5)
2490                self.text_args["horizontalalignment"] = "center"
2491                self.text_args["verticalalignment"] = "center"
2492            else:
2493                raise ValueError(
2494                    "Argument 'corner' must be set to "
2495                    "'upper_left', 'upper_right', 'lower_left', "
2496                    "'lower_right', or None"
2497                )
2498
2499        self.text = ""
2500
2501        # If we're annotating the time, put it in the correct format
2502        if self.time:
2503            # If no time_units are set, then identify a best fit time unit
2504            if self.time_unit is None:
2505                if plot.ds.unit_system._code_flag:
2506                    # if the unit system is in code units
2507                    # we should not convert to seconds for the plot.
2508                    self.time_unit = plot.ds.unit_system.base_units[dimensions.time]
2509                else:
2510                    # in the case of non- code units then we
2511                    self.time_unit = plot.ds.get_smallest_appropriate_unit(
2512                        plot.ds.current_time, quantity="time"
2513                    )
2514            t = plot.ds.current_time.in_units(self.time_unit)
2515            if self.time_offset is not None:
2516                if isinstance(self.time_offset, tuple):
2517                    toffset = plot.ds.quan(self.time_offset[0], self.time_offset[1])
2518                elif isinstance(self.time_offset, Number):
2519                    toffset = plot.ds.quan(self.time_offset, self.time_unit)
2520                elif not isinstance(self.time_offset, YTQuantity):
2521                    raise RuntimeError(
2522                        "'time_offset' must be a float, tuple, or YTQuantity!"
2523                    )
2524                t -= toffset.in_units(self.time_unit)
2525            try:
2526                # here the time unit will be in brackets on the annotation.
2527                un = self.time_unit.latex_representation()
2528                time_unit = r"$\ \ (" + un + r")$"
2529            except AttributeError as err:
2530                if plot.ds.unit_system._code_flag == "code":
2531                    raise RuntimeError(
2532                        "The time unit str repr didn't match expectations, something is wrong."
2533                    ) from err
2534                time_unit = str(self.time_unit).replace("_", " ")
2535            self.text += self.time_format.format(time=float(t), units=time_unit)
2536
2537        # If time and redshift both shown, do one on top of the other
2538        if self.time and self.redshift:
2539            self.text += "\n"
2540
2541        # If we're annotating the redshift, put it in the correct format
2542        if self.redshift:
2543            try:
2544                z = plot.data.ds.current_redshift
2545            except AttributeError:
2546                raise AttributeError(
2547                    "Dataset does not have current_redshift. Set redshift=False."
2548                )
2549            # Replace instances of -0.0* with 0.0* to avoid
2550            # negative null redshifts (e.g., "-0.00").
2551            self.text += self.redshift_format.format(redshift=float(z))
2552            self.text = re.sub("-(0.0*)$", r"\g<1>", self.text)
2553
2554        # This is just a fancy wrapper around the TextLabelCallback
2555        tcb = TextLabelCallback(
2556            self.pos,
2557            self.text,
2558            coord_system=self.coord_system,
2559            text_args=self.text_args,
2560            inset_box_args=self.inset_box_args,
2561        )
2562        return tcb(plot)
2563
2564
class ScaleCallback(PlotCallback):
    r"""
    Annotates the scale of the plot at a specified location in the image
    (either in a preset corner, or by specifying (x,y) image coordinates with
    the pos argument).  Coeff and units (e.g. 1 Mpc or 100 kpc) refer to the
    distance scale you desire to show on the plot.  If no coeff and units are
    specified, an appropriate pair is determined from max_frac of your
    plottable axis length.  Additional customization of the scale bar is
    possible by adjusting the text_args and size_bar_args dictionaries.  The
    text_args dictionary accepts matplotlib's font_properties arguments to
    override the default font_properties for the current plot.  The
    size_bar_args dictionary accepts keyword arguments for the
    AnchoredSizeBar class in matplotlib's axes_grid toolkit.

    Parameters
    ----------

    corner : string, optional
        Corner sets up one of 4 predetermined locations for the scale bar
        to be displayed in the image: 'upper_left', 'upper_right', 'lower_left',
        'lower_right' (also allows None, which centers the bar). This value
        will be overridden by the optional 'pos' keyword.

    coeff : float, optional
        The coefficient of the unit defining the distance scale (e.g. 10 kpc or
        100 Mpc) for overplotting.  If set to None along with unit keyword,
        coeff will be automatically determined to be a power of 10
        relative to the best-fit unit.  Integral coefficients are labelled
        without a decimal part (e.g. "10 kpc"); other values are labelled
        as-is (e.g. "0.5 kpc").

    unit : string, optional
        unit must be a valid yt distance unit (e.g. 'm', 'km', 'AU', 'pc',
        'kpc', etc.) or set to None.  If set to None, will be automatically
        determined to be the best-fit to the data.

    pos : 2-element tuple, list, or array, optional
        The image location of the center of the scale bar, in matplotlib axes
        coordinates ((0,0) is lower left; (1,1) is upper right).  Setting pos
        overrides the corner parameter.

    min_frac, max_frac: float, optional
        The minimum/maximum fraction of the axis width for the scale bar to
        extend. A value of 1 would allow the scale bar to extend across the
        entire axis width.  Only used for automatically calculating
        best-fit coeff and unit when neither is specified, otherwise
        disregarded.  Note: currently only max_frac influences the chosen
        scale; min_frac is accepted but has no effect on the result.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

        - "data": 3D dataset coordinates
        - "plot": 2D coordinates defined by the actual plot limits
        - "axis": MPL axis coordinates: (0,0) is lower left; (1,1) is upper right
        - "figure": MPL figure coordinates: (0,0) is lower left, (1,1) is upper right

        Note: this callback currently always places ``pos`` in axes
        coordinates, whatever the value of this argument.

    text_args : dictionary, optional
        A dictionary of parameters to be used to update the font_properties
        for the text in this callback.  For any property not set, it will
        use the defaults of the plot.  Thus one can modify the text size with
        ``text_args={'size':24}``

    size_bar_args : dictionary, optional
        A dictionary of parameters to be passed to the Matplotlib
        AnchoredSizeBar initializer.  The dictionary is not modified.
        Defaults: ``{'pad': 0.05, 'sep': 5, 'borderpad': 1, 'color': 'w'}``

    draw_inset_box : boolean, optional
        Whether or not an inset box should be included around the scale bar.

    inset_box_args : dictionary, optional
        A dictionary of keyword arguments to be passed to the matplotlib Patch
        object that represents the inset box.
        Defaults: ``{'facecolor': 'black', 'linewidth': 3,
        'edgecolor': 'white', 'alpha': 0.5, 'boxstyle': 'square'}``

    scale_text_format : string, optional
        This specifies the format of the scalebar value assuming "scale" is the
        numerical value and "units" is units of the scale (e.g. 'cm', 'kpc', etc.)
        The scale can be specified to arbitrary precision according to Python
        format-specification codes. The format string must only specify
        "scale" and "units".
        Example: "Length = {scale:.2f} {units}". Default: "{scale} {units}"

    Example
    -------

    >>> import yt
    >>> ds = yt.load("Enzo_64/DD0020/data0020")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_scale()
    """

    _type_name = "scale"
    _supported_geometries = ("cartesian", "spectral_cube", "force")

    def __init__(
        self,
        corner="lower_right",
        coeff=None,
        unit=None,
        pos=None,
        max_frac=0.16,
        min_frac=0.015,
        coord_system="axis",
        text_args=None,
        size_bar_args=None,
        draw_inset_box=False,
        inset_box_args=None,
        scale_text_format="{scale} {units}",
    ):

        def_size_bar_args = {"pad": 0.05, "sep": 5, "borderpad": 1, "color": "w"}

        def_inset_box_args = {
            "facecolor": "black",
            "linewidth": 3,
            "edgecolor": "white",
            "alpha": 0.5,
            "boxstyle": "square",
        }

        # Position is resolved from ``corner`` in __call__ when pos is None.
        self.corner = corner
        self.coeff = coeff
        self.unit = unit
        self.pos = pos
        self.max_frac = max_frac
        self.min_frac = min_frac
        self.coord_system = coord_system
        self.scale_text_format = scale_text_format
        if size_bar_args is None:
            self.size_bar_args = def_size_bar_args
        else:
            self.size_bar_args = size_bar_args
        if inset_box_args is None:
            self.inset_box_args = def_inset_box_args
        else:
            self.inset_box_args = inset_box_args
        self.draw_inset_box = draw_inset_box
        if text_args is None:
            text_args = {}
        self.text_args = text_args

    def __call__(self, plot):
        from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

        # Callback only works for plots with axis ratios of 1
        xsize = plot.xlim[1] - plot.xlim[0]

        # Setting pos overrides corner argument
        if self.pos is None:
            if self.corner == "upper_left":
                self.pos = (0.11, 0.952)
            elif self.corner == "upper_right":
                self.pos = (0.89, 0.952)
            elif self.corner == "lower_left":
                self.pos = (0.11, 0.052)
            elif self.corner == "lower_right":
                self.pos = (0.89, 0.052)
            elif self.corner is None:
                self.pos = (0.5, 0.5)
            else:
                raise ValueError(
                    "Argument 'corner' must be set to "
                    "'upper_left', 'upper_right', 'lower_left', "
                    "'lower_right', or None"
                )

        # When identifying a best fit distance unit, do not allow scale marker
        # to be greater than max_frac fraction of xaxis.
        # NOTE(review): min_scale is converted below but never used, so
        # min_frac currently has no effect on the chosen scale.
        max_scale = self.max_frac * xsize
        min_scale = self.min_frac * xsize

        # If no units are set, pick something sensible.
        if self.unit is None:
            # User has set the axes units and supplied a coefficient.
            if plot._axes_unit_names is not None and self.coeff is not None:
                self.unit = plot._axes_unit_names[0]
            # Nothing provided; identify a best fit distance unit.
            else:
                min_scale = plot.ds.get_smallest_appropriate_unit(
                    min_scale, return_quantity=True
                )
                max_scale = plot.ds.get_smallest_appropriate_unit(
                    max_scale, return_quantity=True
                )
                if self.coeff is None:
                    self.coeff = max_scale.v
                self.unit = max_scale.units
        elif self.coeff is None:
            self.coeff = 1
        self.scale = plot.ds.quan(self.coeff, self.unit)

        # Integral coefficients are labelled without a decimal part ("10 kpc");
        # non-integral ones keep their value instead of being truncated by
        # int() (previously 0.5 kpc was labelled "0 kpc").
        coeff_value = float(self.coeff)
        scale_label = int(coeff_value) if coeff_value.is_integer() else coeff_value
        text = self.scale_text_format.format(scale=scale_label, units=self.unit)

        image_scale = (
            plot.frb.convert_distance_x(self.scale) / plot.frb.convert_distance_x(xsize)
        ).v

        # Work on a copy: popping/adding keys on self.size_bar_args would
        # mutate the caller's dictionary and drop any user-supplied
        # "size_vertical", "fontproperties" or "frameon" on a later call.
        size_bar_args = dict(self.size_bar_args)
        size_vertical = size_bar_args.pop("size_vertical", 0.005 * plot.aspect)
        fontproperties = size_bar_args.pop(
            "fontproperties", plot.font_properties.copy()
        )
        frameon = size_bar_args.pop("frameon", self.draw_inset_box)

        # FontProperties instances use set_<property>() setter functions
        for key, val in self.text_args.items():
            setter_func = "set_" + key
            try:
                getattr(fontproperties, setter_func)(val)
            except AttributeError as e:
                raise AttributeError(
                    "Cannot set text_args keyword "
                    "to include '%s' because MPL's fontproperties object does "
                    "not contain function '%s'." % (key, setter_func)
                ) from e

        # this "anchors" the size bar to a box centered on self.pos in axis
        # coordinates
        size_bar_args["bbox_to_anchor"] = self.pos
        size_bar_args["bbox_transform"] = plot._axes.transAxes

        # loc=10 is matplotlib's location code for "center".
        bar = AnchoredSizeBar(
            plot._axes.transAxes,
            image_scale,
            text,
            10,
            size_vertical=size_vertical,
            fontproperties=fontproperties,
            frameon=frameon,
            **size_bar_args,
        )

        bar.patch.set(**self.inset_box_args)

        plot._axes.add_artist(bar)

        return plot
2799
2800
class RayCallback(PlotCallback):
    """
    Adds a line representing the projected path of a ray across the plot.
    The ray can be either a YTOrthoRay, YTRay, or a LightRay object.
    annotate_ray() will properly account for periodic rays across the volume:
    when the dataset has any periodic axis, the ray is split into
    non-periodic segments which are drawn individually.  If arrow is set to
    True, the final segment is drawn with an arrowhead (ArrowCallback);
    otherwise every segment is drawn as a plain line (LinePlotCallback).
    Adjust plot_args accordingly.

    Parameters
    ----------

    ray : YTOrthoRay, YTRay, or LightRay
        Ray is the object that we want to include.  We overplot the projected
        trajectory of the ray.  If the object is a trident.LightRay
        object, it will only plot the segment of the LightRay that intersects
        the dataset currently displayed.

    arrow : boolean, optional
        Whether or not to place an arrowhead on the front of the ray to denote
        direction
        Default: False

    plot_args : dictionary, optional
        A dictionary of any arbitrary parameters to be passed to the Matplotlib
        line object.  Defaults: {'color':'white', 'linewidth':2}.

    Examples
    --------

    >>> # Overplot a ray and an ortho_ray object on a projection
    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> oray = ds.ortho_ray(1, (0.3, 0.4))  # orthoray down the y axis
    >>> ray = ds.ray((0.1, 0.2, 0.3), (0.6, 0.7, 0.8))  # arbitrary ray
    >>> p = yt.ProjectionPlot(ds, "z", "density")
    >>> p.annotate_ray(oray)
    >>> p.annotate_ray(ray)
    >>> p.save()

    >>> # Overplot a LightRay object on a projection
    >>> import yt
    >>> from trident import LightRay
    >>> ds = yt.load("enzo_cosmology_plus/RD0004/RD0004")
    >>> lr = LightRay(
    ...     "enzo_cosmology_plus/AMRCosmology.enzo", "Enzo", 0.0, 0.1, time_data=False
    ... )
    >>> lray = lr.make_light_ray(seed=1)
    >>> p = yt.ProjectionPlot(ds, "z", "density")
    >>> p.annotate_ray(lr)
    >>> p.save()

    """

    _type_name = "ray"
    _supported_geometries = ("cartesian", "spectral_cube", "force")

    def __init__(self, ray, arrow=False, plot_args=None):
        PlotCallback.__init__(self)
        def_plot_args = {"color": "white", "linewidth": 2}
        self.ray = ray
        self.arrow = arrow
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args

    def _process_ray(self):
        """
        Get the start_coord and end_coord of a ray object
        """
        return (self.ray.start_point, self.ray.end_point)

    def _process_ortho_ray(self):
        """
        Get the start_coord and end_coord of an ortho_ray object.

        The ray spans the whole domain along its axis; its two in-plane
        coordinates are fixed to ``self.ray.coords``.
        """
        start_coord = self.ray.ds.domain_left_edge.copy()
        end_coord = self.ray.ds.domain_right_edge.copy()

        xax = self.ray.ds.coordinates.x_axis[self.ray.axis]
        yax = self.ray.ds.coordinates.y_axis[self.ray.axis]
        start_coord[xax] = end_coord[xax] = self.ray.coords[0]
        start_coord[yax] = end_coord[yax] = self.ray.coords[1]
        return (start_coord, end_coord)

    def _process_light_ray(self, plot):
        """
        Get the start_coord and end_coord of a LightRay object.
        Identify which of the sections of the LightRay is in the
        dataset that is currently being plotted.  If there is one, return the
        start and end of the corresponding ray segment; otherwise return
        ``(None, None)``.
        """

        for ray_ds in self.ray.light_ray_solution:
            if ray_ds["unique_identifier"] == str(plot.ds.unique_identifier):
                start_coord = plot.ds.arr(ray_ds["start"])
                end_coord = plot.ds.arr(ray_ds["end"])
                return (start_coord, end_coord)
        # No segment of the LightRay belongs to the plotted dataset.
        return (None, None)

    def __call__(self, plot):
        type_name = getattr(self.ray, "_type_name", None)

        if type_name == "ray":
            start_coord, end_coord = self._process_ray()

        elif type_name == "ortho_ray":
            start_coord, end_coord = self._process_ortho_ray()

        elif hasattr(self.ray, "light_ray_solution"):
            start_coord, end_coord = self._process_light_ray(plot)

        else:
            raise ValueError("ray must be a YTRay, YTOrthoRay, or LightRay object.")

        # No intersecting ray segment with this plot.  An explicit sentinel
        # is used because testing coordinate truthiness would also skip real
        # rays whose endpoints contain zero-valued coordinates.
        if start_coord is None or end_coord is None:
            return plot

        # if possible, break periodic ray into non-periodic
        # segments and add each of them individually
        if any(plot.ds.periodicity):
            segments = periodic_ray(
                start_coord.to("code_length"),
                end_coord.to("code_length"),
                left=plot.ds.domain_left_edge.to("code_length"),
                right=plot.ds.domain_right_edge.to("code_length"),
            )
        else:
            segments = [[start_coord, end_coord]]

        # To assure that the last ray segment has an arrow if so desired
        # and all other ray segments are lines
        for segment in segments[:-1]:
            cb = LinePlotCallback(
                segment[0], segment[1], coord_system="data", plot_args=self.plot_args
            )
            cb(plot)
        segment = segments[-1]
        if self.arrow:
            cb = ArrowCallback(
                segment[1],
                starting_pos=segment[0],
                coord_system="data",
                plot_args=self.plot_args,
            )
        else:
            cb = LinePlotCallback(
                segment[0], segment[1], coord_system="data", plot_args=self.plot_args
            )
        cb(plot)
        return plot
2956
2957
class LineIntegralConvolutionCallback(PlotCallback):
    """
    Overlay a line integral convolution (LIC) of a two-component vector field.

    Both named components (e.g. velocity_x and velocity_y, or
    magnetic_field_x and magnetic_field_y) are pixelized onto the same grid as
    the plotted image.  A texture is then convolved along the resulting field
    lines, normalized by its maximum, clipped to ``lim``, and drawn on top of
    the plot.

    Parameters
    ----------

    field_x, field_y : string
        Names of the two vector-field components to visualize.

    texture : 2-d array with the same shape of image, optional
        Texture convolved along the field lines.  When omitted, a uniform
        random (white-noise) texture is generated the first time the callback
        runs.

    kernellen : float, optional
        Length of the convolution kernel, i.e. the distance along a field line
        over which the convolution is performed.  Larger values produce longer
        streamline-like structures.

    lim : 2-element tuple, list, or array, optional
        The normalized LIC values are clipped to [lim[0], lim[1]], bounding
        the displayed values and enhancing contrast.  Both elements should lie
        in [0, 1].

    cmap : string, optional
        Name of the colormap used to render the LIC image.

    alpha : float, optional
        Alpha (opacity) of the LIC overlay.

    const_alpha : boolean, optional
        If False (the default), the per-pixel opacity rises linearly from 0 at
        lim[0] to ``alpha`` at lim[1]; if True, a uniform opacity of ``alpha``
        is used.

    Example
    -------

    >>> import yt
    >>> ds = yt.load("Enzo_64/DD0020/data0020")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_line_integral_convolution(
    ...     "velocity_x", "velocity_y", lim=(0.5, 0.65)
    ... )
    """

    _type_name = "line_integral_convolution"
    _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical")

    def __init__(
        self,
        field_x,
        field_y,
        texture=None,
        kernellen=50.0,
        lim=(0.5, 0.6),
        cmap="binary",
        alpha=0.8,
        const_alpha=False,
    ):
        PlotCallback.__init__(self)
        self.field_x, self.field_y = field_x, field_y
        self.texture = texture
        self.kernellen = kernellen
        self.lim = lim
        self.cmap = cmap
        self.alpha, self.const_alpha = alpha, const_alpha

    def __call__(self, plot):
        from matplotlib import cm

        bounds = self._physical_bounds(plot)
        extent = self._plot_bounds(plot)

        # The pixelizer is handed (nx, ny) and reverses the order itself.
        image_shape = plot.image._A.shape
        nx, ny = image_shape[1], image_shape[0]

        coords = plot.data.ds.coordinates
        comp_x = coords.pixelize(
            plot.data.axis, plot.data, self.field_x, bounds, (nx, ny)
        )
        comp_y = coords.pixelize(
            plot.data.axis, plot.data, self.field_y, bounds, (nx, ny)
        )
        vectors = np.concatenate(
            [comp_x[..., np.newaxis], comp_y[..., np.newaxis]], axis=2
        )

        if self.texture is not None:
            if self.texture.shape != (nx, ny):
                raise ValueError(
                    "'texture' must have the same shape "
                    "with that of output image (%d, %d)" % (nx, ny)
                )
        else:
            # White-noise texture, kept on the instance for later calls.
            self.texture = np.random.rand(nx, ny).astype(np.double)

        # Half-period sine window of length ``kernellen``.
        kernel = np.sin(np.arange(self.kernellen) * np.pi / self.kernellen).astype(
            np.double
        )

        lo, hi = self.lim[0], self.lim[1]
        lic = line_integral_convolution_2d(vectors, self.texture, kernel)
        lic = np.clip(lic / lic.max(), lo, hi)

        shared_kwargs = {
            "extent": extent,
            "cmap": self.cmap,
            "origin": "lower",
            "aspect": "auto",
        }
        if not self.const_alpha:
            rgba = cm.ScalarMappable(norm=None, cmap=self.cmap).to_rgba(lic)
            # Opacity follows the clipped value, rescaled to [0, alpha].
            rgba[..., 3] = (lic - lo) / (hi - lo) * self.alpha
            plot._axes.imshow(rgba, **shared_kwargs)
        else:
            plot._axes.imshow(lic, alpha=self.alpha, **shared_kwargs)

        return plot
3091
3092
class CellEdgesCallback(PlotCallback):
    """
    Annotate cell edges.  This is done through a second call to pixelize, where
    the distance from a pixel to a cell boundary in pixels is compared against
    a width derived from the `line_width` argument.  The secondary image is
    colored as `color` and overlaid with the `alpha` value.

    The overlay is rendered into its own buffer, whose pixels are physically
    square (so horizontal and vertical edges get the same thickness) and whose
    short axis is upsampled to at least 1600 pixels.

    Supported geometries are cartesian, spectral_cube and cylindrical; for
    cylindrical datasets only 2D data is supported.

    Parameters
    ----------
    line_width : float
        The width of the cell edge lines, as a fraction of the number of
        pixels along the longest axis of the overlay buffer.  Lines are always
        at least one pixel wide.  Default is 0.002 (0.2% of the longest axis).
    alpha : float
        When the second image is overlaid, it will have this level of alpha
        transparency.  Default is 1.0 (fully-opaque).
    color : tuple of three floats or matplotlib color name
        This is the color of the cell edge values.  It defaults to black.

    Examples
    --------

    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> s = yt.SlicePlot(ds, "z", "density")
    >>> s.annotate_cell_edges()
    >>> s.save()
    """

    _type_name = "cell_edges"
    _supported_geometries = ("cartesian", "spectral_cube", "cylindrical")

    def __init__(self, line_width=0.002, alpha=1.0, color="black"):
        from matplotlib.colors import to_rgb

        super().__init__()
        self.line_width = line_width
        self.alpha = alpha
        # Stored as 8-bit RGB so it can be written directly into the uint8
        # RGBA overlay buffer built in __call__.
        self.color = (np.array(to_rgb(color)) * 255).astype("uint8")

    @staticmethod
    def _square_pixel_shape(nx, ny, aspect, min_short_axis=1600):
        """Return ``(nx, ny, long_axis)`` for a buffer with square pixels.

        The ``(nx, ny)`` buffer is first shrunk along whichever axis has too
        many pixels, so that ``ny / nx`` matches the physical ``aspect``
        (height / width).  It is then scaled up uniformly so its short axis
        has at least ``min_short_axis`` pixels.  ``long_axis`` is the pixel
        count along the longer of the two resulting axes.
        """
        relative_aspect = (float(ny) / nx) / aspect
        if relative_aspect > 1:
            # More rows than the physical height warrants: drop rows.
            ny = max(int(ny / relative_aspect), 1)
        else:
            # More columns than the physical width warrants: drop columns.
            nx = max(int(nx * relative_aspect), 1)
        if aspect > 1:
            # Taller than wide: x is the short axis.
            if nx < min_short_axis:
                ny = int(float(min_short_axis) / nx * ny)
                nx = min_short_axis
            long_axis = ny
        else:
            # Wider than tall (or square): y is the short axis.
            if ny < min_short_axis:
                nx = int(float(min_short_axis) / ny * nx)
                ny = min_short_axis
            long_axis = nx
        return nx, ny, long_axis

    def __call__(self, plot):
        """Overlay the cell edges of ``plot.data`` onto ``plot._axes``.

        Raises
        ------
        NotImplementedError
            If the dataset has cylindrical geometry and is 3D.
        """
        ds = plot.data.ds
        if ds.geometry == "cylindrical" and ds.dimensionality == 3:
            raise NotImplementedError(
                "Cell edge annotation is only supported "
                "for 2D cylindrical geometry, not 3D"
            )
        x0, x1, y0, y1 = self._physical_bounds(plot)
        xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
        # Start from the pixel shape of the image already drawn on the axes.
        ny, nx = plot.image.get_array().shape[:2]
        aspect = float((y1 - y0) / (x1 - x0))
        nx, ny, long_axis = self._square_pixel_shape(nx, ny, aspect)
        line_width = max(self.line_width * long_axis, 1.0)
        im = np.zeros((ny, nx), dtype="f8")
        pixelize_cartesian(
            im,
            plot.data["px"],
            plot.data["py"],
            plot.data["pdx"],
            plot.data["pdy"],
            plot.data["px"],  # dummy field: only the `im > 0` mask is used
            (x0, x1, y0, y1),
            line_width=line_width,
        )
        # RGBA overlay: opaque `color` where pixelize left a positive value
        # (assumed to be the edge pixels when line_width is set), fully
        # transparent elsewhere.
        edges = im > 0
        im_buffer = np.zeros((ny, nx, 4), dtype="uint8")
        im_buffer[edges, 3] = 255
        im_buffer[edges, :3] = self.color
        plot._axes.imshow(
            im_buffer,
            origin="lower",
            interpolation="bilinear",
            extent=[xx0, xx1, yy0, yy1],
            alpha=self.alpha,
        )
        # imshow may rescale the view; restore the plot's own bounds.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
3186