1from io import BytesIO
2
3import matplotlib
4import numpy as np
5from packaging.version import parse as parse_version
6
7from yt.funcs import (
8    get_brewer_cmap,
9    get_interactivity,
10    is_sequence,
11    matplotlib_style_context,
12    mylog,
13)
14
15from ._commons import get_canvas, validate_image_name
16
# Matplotlib backend name, as returned by ``matplotlib.get_backend()``, mapped to
# ``[submodule of matplotlib.backends, canvas class name, manager class name or
# None]``. PlotMPL imports the submodule and instantiates these classes.
BACKEND_SPECS = dict(
    GTK=["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"],
    GTKAgg=["backend_gtkagg", "FigureCanvasGTKAgg", None],
    GTKCairo=["backend_gtkcairo", "FigureCanvasGTKCairo", None],
    MacOSX=["backend_macosx", "FigureCanvasMac", "FigureManagerMac"],
    Qt4Agg=["backend_qt4agg", "FigureCanvasQTAgg", None],
    Qt5Agg=["backend_qt5agg", "FigureCanvasQTAgg", None],
    TkAgg=["backend_tkagg", "FigureCanvasTkAgg", None],
    WX=["backend_wx", "FigureCanvasWx", None],
    WXAgg=["backend_wxagg", "FigureCanvasWxAgg", None],
    GTK3Cairo=[
        "backend_gtk3cairo",
        "FigureCanvasGTK3Cairo",
        "FigureManagerGTK3Cairo",
    ],
    GTK3Agg=["backend_gtk3agg", "FigureCanvasGTK3Agg", "FigureManagerGTK3Agg"],
    WebAgg=["backend_webagg", "FigureCanvasWebAgg", None],
    nbAgg=["backend_nbagg", "FigureCanvasNbAgg", "FigureManagerNbAgg"],
    agg=["backend_agg", "FigureCanvasAgg", None],
)
37
38
class CallbackWrapper:
    """Single object exposing the plot state that callbacks work with.

    State is copied from three sources: ``viewer`` (the plot window: limits,
    axis unit names, plot type), ``window_plot`` (its matplotlib axes, figure
    and aspect) and ``frb`` (presumably a fixed resolution buffer — confirm:
    its data source, dataset and image axis).

    Two attributes are only set conditionally:

    * ``image`` -- the first image on the axes, only if the axes has images.
    * ``_period`` -- the dataset's domain widths along the image x and y
      axes, only when ``frb.axis < 3``.
    """

    def __init__(self, viewer, window_plot, frb, field, font_properties, font_color):
        self.frb = frb
        self.data = frb.data_source

        mpl_axes = window_plot.axes
        self._axes = mpl_axes
        self._figure = window_plot.figure
        if len(mpl_axes.images) != 0:
            self.image = mpl_axes.images[0]

        axis_index = frb.axis
        if axis_index < 3:
            widths = frb.ds.domain_width
            coords = frb.ds.coordinates
            self._period = (
                widths[coords.x_axis[axis_index]],
                widths[coords.y_axis[axis_index]],
            )

        self.ds = frb.ds
        self.xlim, self.ylim = viewer.xlim, viewer.ylim
        self._axes_unit_names = viewer._axes_unit_names
        plot_type = viewer._plot_type
        # Off-axis slices are reported to callbacks as cutting planes.
        self._type_name = "CuttingPlane" if "OffAxisSlice" in plot_type else plot_type
        self.aspect = window_plot._aspect
        self.font_properties = font_properties
        self.font_color = font_color
        self.field = field
64
65
class PlotMPL:
    """A base class for all yt plots made using matplotlib, that is backend independent.

    Owns a matplotlib ``figure``, the main ``axes`` and a ``canvas`` (plus a
    ``manager`` when the selected backend spec provides one). Implements
    saving, showing and font styling shared by yt's matplotlib plots.
    """

    def __init__(self, fsize, axrect, figure, axes):
        """Initialize PlotMPL class

        Parameters
        ----------
        fsize : float or sequence of 2 floats
            Figure size in inches; a scalar means a square ``(fsize, fsize)``
            figure.
        axrect : sequence of 4 floats
            ``(left, bottom, width, height)`` of the main axes, as fractions
            of the figure size.
        figure : matplotlib.figure.Figure or None
            Figure to draw into (it is resized to ``fsize``); a new figure is
            created when None.
        axes : matplotlib.axes.Axes or None
            Axes to reuse (cleared and moved to ``axrect``); new axes are
            added to the figure when None.
        """
        import matplotlib.figure

        self._plot_valid = True
        # Normalize up front so both branches get a (width, height) pair:
        # Figure.set_size_inches cannot unpack a lone scalar.
        if not is_sequence(fsize):
            fsize = (fsize, fsize)
        if figure is None:
            self.figure = matplotlib.figure.Figure(figsize=fsize, frameon=True)
        else:
            figure.set_size_inches(fsize)
            self.figure = figure
        if axes is None:
            self._create_axes(axrect)
        else:
            axes.cla()
            axes.set_position(axrect)
            self.axes = axes
        self.interactivity = get_interactivity()

        figure_canvas, figure_manager = self._get_canvas_classes()
        self.canvas = figure_canvas(self.figure)
        if figure_manager is not None:
            # The second argument is the manager's figure number, fixed to 1.
            self.manager = figure_manager(self.canvas, 1)

        # Inward-pointing major and minor ticks on all four sides.
        for which in ["major", "minor"]:
            for axis in "xy":
                self.axes.tick_params(
                    which=which, axis=axis, direction="in", top=True, right=True
                )

    def _create_axes(self, axrect):
        """Add the main axes to the figure at ``axrect``."""
        self.axes = self.figure.add_axes(axrect)

    def _get_canvas_classes(self):
        """Return ``(FigureCanvas, FigureManager or None)`` for the active backend.

        With interactivity disabled the Agg canvas is used. A backend name
        missing from ``BACKEND_SPECS`` (for instance a ``module://`` backend)
        also falls back to Agg; previously this returned None, which callers
        cannot unpack.
        """
        if self.interactivity:
            key = str(matplotlib.get_backend())
        else:
            key = "agg"

        try:
            module, fig_canvas, fig_manager = BACKEND_SPECS[key]
        except KeyError:
            mylog.debug("Unsupported matplotlib backend %r, falling back to agg", key)
            module, fig_canvas, fig_manager = BACKEND_SPECS["agg"]

        # Import matplotlib.backends.<module> and fetch the submodule from it.
        mod = __import__(
            "matplotlib.backends",
            globals(),
            locals(),
            [module],
            0,
        )
        submod = getattr(mod, module)
        FigureCanvas = getattr(submod, fig_canvas)
        if fig_manager is not None:
            FigureManager = getattr(submod, fig_manager)
            return FigureCanvas, FigureManager

        return FigureCanvas, None

    def save(self, name, mpl_kwargs=None, canvas=None):
        """Choose backend and save image to disk

        Parameters
        ----------
        name : str
            Output file name, normalized by ``validate_image_name``.
        mpl_kwargs : dict, optional
            Extra keyword arguments for ``print_figure``. The caller's dict
            is copied and never modified.
        canvas : optional
            Currently unused. The canvas is always chosen by
            ``get_canvas(self.figure, name)``, falling back to
            ``self.canvas`` when that raises ValueError.

        Returns
        -------
        str
            The validated file name that was written.
        """
        # Copy so the "papertype" default below does not leak into a dict the
        # caller may reuse for other saves.
        mpl_kwargs = {} if mpl_kwargs is None else dict(mpl_kwargs)
        # Only default papertype="auto" on matplotlib < 3.3 (presumably
        # because that value was deprecated in 3.3).
        if "papertype" not in mpl_kwargs and parse_version(
            matplotlib.__version__
        ) < parse_version("3.3.0"):
            mpl_kwargs["papertype"] = "auto"

        name = validate_image_name(name)

        try:
            canvas = get_canvas(self.figure, name)
        except ValueError:
            canvas = self.canvas

        mylog.info("Saving plot %s", name)
        with matplotlib_style_context():
            canvas.print_figure(name, **mpl_kwargs)
        return name

    def show(self):
        """Show the figure via the figure manager, or the canvas when there is
        no manager (any AttributeError triggers the canvas fallback)."""
        try:
            self.manager.show()
        except AttributeError:
            self.canvas.show()

    def _get_labels(self):
        """Return the Text artists styled by ``_set_font_properties``: major and
        minor tick labels, the title, both axis labels and offset texts."""
        ax = self.axes
        labels = ax.xaxis.get_ticklabels() + ax.yaxis.get_ticklabels()
        labels += ax.xaxis.get_minorticklabels()
        labels += ax.yaxis.get_minorticklabels()
        labels += [
            ax.title,
            ax.xaxis.label,
            ax.yaxis.label,
            ax.xaxis.get_offset_text(),
            ax.yaxis.get_offset_text(),
        ]
        return labels

    def _set_font_properties(self, font_properties, font_color):
        """Apply ``font_properties`` to every label from ``_get_labels``, and
        ``font_color`` too unless it is None."""
        for label in self._get_labels():
            label.set_fontproperties(font_properties)
            if font_color is not None:
                # Use the argument: PlotMPL never defines self.font_color.
                label.set_color(font_color)

    def _repr_png_(self):
        """Return the figure rendered as PNG bytes (IPython rich display hook)."""
        from ._mpl_imports import FigureCanvasAgg

        canvas = FigureCanvasAgg(self.figure)
        f = BytesIO()
        with matplotlib_style_context():
            canvas.print_figure(f)
        f.seek(0)
        return f.read()
187
188
class ImagePlotMPL(PlotMPL):
    """A base class for yt plots made using imshow

    Adds a colorbar axes (``cax``) next to the main image axes. It draws the
    image and its colorbar (``_init_image``). It also recomputes the figure
    layout when the axes decorations or the colorbar are shown or hidden.

    NOTE(review): this class reads ``_transform``, ``_figure_size``,
    ``_aspect``, ``_cb_size``, ``_ax_text_size`` and ``_top_buff_size`` but
    never assigns them. ``_draw_axes``/``_draw_colorbar`` are only assigned
    by the toggle methods. Presumably subclasses set all of these before the
    imshow/layout helpers run — confirm.
    """

    def __init__(self, fsize, axrect, caxrect, zlim, figure, axes, cax):
        """Initialize ImagePlotMPL class object

        Parameters
        ----------
        fsize, axrect, figure, axes
            Forwarded to :class:`PlotMPL`.
        caxrect : sequence of 4 floats
            ``(left, bottom, width, height)`` of the colorbar axes, as
            fractions of the figure (same convention as ``axrect``).
        zlim : pair
            ``(zmin, zmax)`` color limits; either entry may be None.
        cax : matplotlib.axes.Axes or None
            Colorbar axes to reuse (cleared and moved to ``caxrect``); a new
            axes is added to the figure when None.
        """
        super().__init__(fsize, axrect, figure, axes)
        self.zmin, self.zmax = zlim
        if cax is None:
            self.cax = self.figure.add_axes(caxrect)
        else:
            cax.cla()
            cax.set_position(caxrect)
            self.cax = cax

    def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
        """Store output of imshow in image variable

        Draws ``data`` on ``self.axes`` and a colorbar on ``self.cax``. The
        imshow result is stored in ``self.image`` and the colorbar in
        ``self.cb``.

        Parameters
        ----------
        data : array
            Image values. NOTE(review): ``data.to_ndarray()`` and ``.v`` are
            used below, so this is presumably a unyt array — confirm.
        cbnorm : {"log10", "linear", "symlog"}
            Colorbar normalization.
        cblinthresh : float or None
            Linear threshold, used only for "symlog". When None, the smallest
            non-zero absolute value in ``data`` is used.
        cmap : colormap or tuple
            A tuple is resolved with ``get_brewer_cmap``. Anything else is
            passed to imshow unchanged.
        extent : sequence of 4 numbers
            imshow extent; each entry is converted to float.
        aspect :
            Passed to imshow.

        Raises
        ------
        ValueError
            For an unknown ``cbnorm``.

        Notes
        -----
        With "symlog", vmin/vmax come from the nan-aware min/max of ``data``
        and replace ``self.zmin``/``self.zmax``. If ``cblinthresh`` is None
        and every value of ``data`` is zero, the default-threshold reduction
        runs on an empty array and numpy raises.
        """
        # Color limits from zlim; None entries are passed through as None.
        cbnorm_kwargs = dict(
            vmin=float(self.zmin) if self.zmin is not None else None,
            vmax=float(self.zmax) if self.zmax is not None else None,
        )
        if cbnorm == "log10":
            cbnorm_cls = matplotlib.colors.LogNorm
        elif cbnorm == "linear":
            cbnorm_cls = matplotlib.colors.Normalize
        elif cbnorm == "symlog":
            # if cblinthresh is not specified, try to come up with a reasonable default
            # (vmin/vmax from the data replace the zlim-based values above)
            vmin = float(np.nanmin(data))
            vmax = float(np.nanmax(data))
            if cblinthresh is None:
                cblinthresh = np.nanmin(np.absolute(data)[data != 0])

            cbnorm_kwargs.update(dict(linthresh=cblinthresh, vmin=vmin, vmax=vmax))
            MPL_VERSION = parse_version(matplotlib.__version__)
            if MPL_VERSION >= parse_version("3.2.0"):
                # note that this creates an inconsistency between mpl versions
                # since the default value previous to mpl 3.4.0 is np.e
                # but it is only exposed since 3.2.0
                cbnorm_kwargs["base"] = 10

            cbnorm_cls = matplotlib.colors.SymLogNorm
        else:
            raise ValueError(f"Unknown value `cbnorm` == {cbnorm}")

        norm = cbnorm_cls(**cbnorm_kwargs)

        extent = [float(e) for e in extent]
        # tuple colormaps are from palettable (or brewer2mpl)
        if isinstance(cmap, tuple):
            cmap = get_brewer_cmap(cmap)

        if self._transform is None:
            # sets the transform to be an ax.TransData object, where the
            # coordinate system of the data is controlled by the xlim and ylim
            # of the data.
            transform = self.axes.transData
        else:
            transform = self._transform
        if hasattr(self.axes, "set_extent"):
            # CartoPy hangs if we do not set_extent before imshow if we are
            # displaying a small subset of the globe.  What I believe happens is
            # that the transform for the points on the outside results in
            # infinities, and then the scipy.spatial cKDTree hangs trying to
            # identify nearest points.
            #
            # Also, set_extent is defined by cartopy, so not all axes will have
            # it as a method.
            #
            # A potential downside is that other images may change, but I believe
            # the result of imshow is to set_extent *regardless*.  This just
            # changes the order in which it happens.
            #
            # NOTE: This is currently commented out because it breaks in some
            # instances.  It is left as a historical note because we will
            # eventually need some form of it.
            # self.axes.set_extent(extent)
            pass
        self.image = self.axes.imshow(
            data.to_ndarray(),
            origin="lower",
            extent=extent,
            norm=norm,
            aspect=aspect,
            cmap=cmap,
            interpolation="nearest",
            transform=transform,
        )
        if cbnorm == "symlog":
            # Symlog colorbars get explicit ticks at powers of ten, built from
            # the data range and the linear threshold.
            formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh)
            self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
            if np.nanmin(data) >= 0.0:
                # Non-negative data: the data minimum, then 10**k for k from
                # rint(log10(linthresh)) up to ceil(log10(max)) inclusive.
                yticks = [np.nanmin(data).v] + list(
                    10
                    ** np.arange(
                        np.rint(np.log10(cblinthresh)),
                        np.ceil(np.log10(np.nanmax(data))) + 1,
                    )
                )
            elif np.nanmax(data) <= 0.0:
                # Non-positive data: -10**k for k from floor(log10(-min)) down
                # to rint(log10(linthresh)) inclusive, then the data maximum.
                yticks = (
                    list(
                        -(
                            10
                            ** np.arange(
                                np.floor(np.log10(-np.nanmin(data))),
                                np.rint(np.log10(cblinthresh)) - 1,
                                -1,
                            )
                        )
                    )
                    + [np.nanmax(data).v]
                )
            else:
                # Mixed signs: the negative powers of ten, 0, then the
                # positive powers of ten (same ranges as the branches above).
                yticks = (
                    list(
                        -(
                            10
                            ** np.arange(
                                np.floor(np.log10(-np.nanmin(data))),
                                np.rint(np.log10(cblinthresh)) - 1,
                                -1,
                            )
                        )
                    )
                    + [0]
                    + list(
                        10
                        ** np.arange(
                            np.rint(np.log10(cblinthresh)),
                            np.ceil(np.log10(np.nanmax(data))) + 1,
                        )
                    )
                )
            self.cb.set_ticks(yticks)
        else:
            self.cb = self.figure.colorbar(self.image, self.cax)
        # Inward-pointing colorbar ticks, like the main axes.
        for which in ["major", "minor"]:
            self.cax.tick_params(which=which, axis="y", direction="in")

    def _get_best_layout(self):
        """Compute the figure size and the axes rectangles for current settings.

        The figure is split horizontally into
        ``[left margin, image, colorbar, colorbar text]`` and vertically into
        ``[bottom margin, image, top buffer]``. All sizes use the same units
        as ``_figure_size``; the toggle methods pass the resulting size to
        ``figure.set_size_inches``.

        Returns
        -------
        size : list of 2 floats
            Total ``[width, height]`` of the figure.
        axrect, caxrect : tuple of 4 floats
            ``(left, bottom, width, height)`` of the image axes and of the
            colorbar axes, as fractions of the figure size.
        """

        # Image size: a sequence _figure_size gives (width, height) directly;
        # a scalar is the width, and the height is derived from _aspect.
        if is_sequence(self._figure_size):
            x_fig_size = self._figure_size[0]
            y_fig_size = self._figure_size[1]
        else:
            x_fig_size = self._figure_size
            y_fig_size = self._figure_size / self._aspect

        # Optional extra vertical scaling, applied only if the attribute exists.
        if hasattr(self, "_unit_aspect"):
            y_fig_size = y_fig_size * self._unit_aspect

        if self._draw_colorbar:
            # Colorbar strip plus room for its labels (text size + 0.45 pad).
            cb_size = self._cb_size
            cb_text_size = self._ax_text_size[1] + 0.45
        else:
            # Hidden colorbar: keep a narrow strip (4% of the image width).
            cb_size = x_fig_size * 0.04
            cb_text_size = 0.0

        if self._draw_axes:
            x_axis_size = self._ax_text_size[0]
            y_axis_size = self._ax_text_size[1]
        else:
            # Hidden axes: small margins (4% of the image size).
            x_axis_size = x_fig_size * 0.04
            y_axis_size = y_fig_size * 0.04

        top_buff_size = self._top_buff_size

        # Neither axes nor colorbar: drop every margin so only the image remains.
        if not self._draw_axes and not self._draw_colorbar:
            x_axis_size = 0.0
            y_axis_size = 0.0
            cb_size = 0.0
            cb_text_size = 0.0
            top_buff_size = 0.0

        # Left to right: margin, image, colorbar, colorbar text.
        xbins = np.array([x_axis_size, x_fig_size, cb_size, cb_text_size])
        # Bottom to top: margin, image, top buffer.
        ybins = np.array([y_axis_size, y_fig_size, top_buff_size])

        size = [xbins.sum(), ybins.sum()]

        x_frac_widths = xbins / size[0]
        y_frac_widths = ybins / size[1]

        # axrect is the rectangle defining the area of the
        # axis object of the plot.  Its range goes from 0 to 1 in
        # x and y directions.  The first two values are the x,y
        # start values of the axis object (lower left corner), and the
        # second two values are the size of the axis object.  To get
        # the upper right corner, add the first x,y to the second x,y.
        axrect = (
            x_frac_widths[0],
            y_frac_widths[0],
            x_frac_widths[1],
            y_frac_widths[1],
        )

        # caxrect is the rectangle defining the area of the colorbar
        # axis object of the plot.  It is defined just as the axrect
        # tuple is.
        caxrect = (
            x_frac_widths[0] + x_frac_widths[1],
            y_frac_widths[0],
            x_frac_widths[2],
            y_frac_widths[1],
        )

        return size, axrect, caxrect

    def _toggle_axes(self, choice, draw_frame=None):
        """
        Turn on/off displaying the axis ticks and labels for a plot.

        Afterwards the figure size and the axes/colorbar positions are
        recomputed with ``_get_best_layout``.

        Parameters
        ----------
        choice : boolean
            If True, set the axes to be drawn. If False, set the axes to not be
            drawn.
        draw_frame : boolean, optional
            Whether to draw the axes frame (``Axes.set_frame_on``). Defaults
            to ``choice``.
        """
        if draw_frame is None:
            draw_frame = choice
        self._draw_axes = choice
        self._draw_frame = draw_frame
        self.axes.set_frame_on(draw_frame)
        self.axes.get_xaxis().set_visible(choice)
        self.axes.get_yaxis().set_visible(choice)
        size, axrect, caxrect = self._get_best_layout()
        self.axes.set_position(axrect)
        self.cax.set_position(caxrect)
        self.figure.set_size_inches(*size)

    def _toggle_colorbar(self, choice):
        """
        Turn on/off displaying the colorbar for a plot

        choice = True or False

        Shows or hides ``cax`` and recomputes the figure size and the
        axes/colorbar positions with ``_get_best_layout``. A hidden colorbar
        still keeps a narrow strip reserved there.
        """
        self._draw_colorbar = choice
        self.cax.set_visible(choice)
        size, axrect, caxrect = self._get_best_layout()
        self.axes.set_position(axrect)
        self.cax.set_position(caxrect)
        self.figure.set_size_inches(*size)

    def _get_labels(self):
        """Base-class labels plus the colorbar's tick labels, axis label and
        offset text. Requires ``self.cb`` (created by ``_init_image``)."""
        labels = super()._get_labels()
        cbax = self.cb.ax
        labels += cbax.yaxis.get_ticklabels()
        labels += [cbax.yaxis.label, cbax.yaxis.get_offset_text()]
        return labels

    def hide_axes(self, draw_frame=None):
        """
        Hide the axes for a plot including ticks and labels

        Parameters
        ----------
        draw_frame : boolean, optional
            Whether to keep drawing the axes frame. Defaults to hiding it
            along with the axes.

        Returns self, so calls can be chained.
        """
        self._toggle_axes(False, draw_frame)
        return self

    def show_axes(self):
        """
        Show the axes for a plot including ticks and labels

        Returns self, so calls can be chained.
        """
        self._toggle_axes(True)
        return self

    def hide_colorbar(self):
        """
        Hide the colorbar for a plot including ticks and labels

        Returns self, so calls can be chained.
        """
        self._toggle_colorbar(False)
        return self

    def show_colorbar(self):
        """
        Show the colorbar for a plot including ticks and labels

        Returns self, so calls can be chained.
        """
        self._toggle_colorbar(True)
        return self
465
466
def get_multi_plot(nx, ny, colorbar="vertical", bw=4, dpi=300, cbar_padding=0.4):
    r"""Construct a multiple axes plot object, with or without a colorbar, into
    which multiple plots may be inserted.

    This will create a set of :class:`matplotlib.axes.Axes`, all lined up into
    a grid, which are then returned to the user and which can be used to plot
    multiple plots on a single figure.

    Parameters
    ----------
    nx : int
        Number of axes to create along the x-direction
    ny : int
        Number of axes to create along the y-direction
    colorbar : {'vertical', 'horizontal', None}, optional
        Should Axes objects for colorbars be allocated, and if so, should they
        correspond to the horizontal or vertical set of axes? Matching is
        case-insensitive; any other string allocates no colorbar axes.
    bw : number
        The base height/width of an axes object inside the figure, in inches
    dpi : number
        The dots per inch fed into the Figure instantiation
    cbar_padding : number, optional
        Extra room added to the figure for the colorbars, in units of ``bw``.
        With vertical colorbars the figure is ``bw * (nx + cbar_padding)``
        wide. With horizontal colorbars it is ``bw * (ny + cbar_padding)``
        tall.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        The figure created inside which the axes reside
    tr : list of list of :class:`matplotlib.axes.Axes` objects
        This is a list, where the inner list is along the x-axis and the outer
        is along the y-axis; ``tr[0]`` is the top row.
    cbars : list of :class:`matplotlib.axes.Axes` objects
        Each of these is an axes onto which a colorbar can be placed: one per
        column for horizontal colorbars, one per row (top row first) for
        vertical colorbars, and none otherwise.

    Notes
    -----
    This is a simple implementation for a common use case.  Viewing the source
    can be instructive, and is encouraged to see how to generate more
    complicated or more specific sets of multiplots for your own purposes.
    """
    import matplotlib.figure

    # Normalize once; None (or any unrecognized string) means "no colorbars".
    orientation = None if colorbar is None else colorbar.lower()

    hf, wf = 1.0 / ny, 1.0 / nx
    # fudge_x / fudge_y: fraction of the figure width / height used by the
    # grid of plot axes; the remainder is left free for the colorbars.
    fudge_x = fudge_y = 1.0
    if orientation == "vertical":
        fudge_x = nx / (cbar_padding + nx)
    elif orientation == "horizontal":
        fudge_y = ny / (cbar_padding + ny)
    fig = matplotlib.figure.Figure((bw * nx / fudge_x, bw * ny / fudge_y), dpi=dpi)
    from ._mpl_imports import FigureCanvasAgg

    fig.set_canvas(FigureCanvasAgg(fig))
    fig.subplots_adjust(
        wspace=0.0, hspace=0.0, top=1.0, bottom=0.0, left=0.0, right=1.0
    )

    # Row j = 0 is the top row. With horizontal colorbars the whole grid is
    # shifted up by (1 - fudge_y) to leave a strip at the bottom.
    tr = []
    for j in range(ny):
        bottom = fudge_y * (1.0 - (j + 1) * hf) + (1.0 - fudge_y)
        tr.append(
            [
                fig.add_axes([i * wf * fudge_x, bottom, wf * fudge_x, hf * fudge_y])
                for i in range(nx)
            ]
        )

    cbars = []
    if orientation == "horizontal":
        for i in range(nx):
            # [left, bottom, width, height] in figure fractions: inset by 10%
            # of a column width on each side, 5% of a row height tall, with
            # its bottom edge 20% of a row height above the figure's bottom.
            ax = fig.add_axes(
                [
                    wf * (i + 0.10) * fudge_x,
                    hf * fudge_y * 0.20,
                    wf * (1 - 0.20) * fudge_x,
                    hf * fudge_y * 0.05,
                ]
            )
            cbars.append(ax)
    elif orientation == "vertical":
        for j in range(ny):
            # Right of the grid (offset by 5% of a column width), 5% of a
            # column width wide, spanning the middle 90% of row j's height.
            ax = fig.add_axes(
                [
                    wf * (nx + 0.05) * fudge_x,
                    hf * fudge_y * (ny - (j + 0.95)),
                    wf * fudge_x * 0.05,
                    hf * fudge_y * 0.90,
                ]
            )
            ax.clear()
            cbars.append(ax)
    return fig, tr, cbars
563