import importlib
from io import BytesIO

import matplotlib
import numpy as np
from packaging.version import parse as parse_version

from yt.funcs import (
    get_brewer_cmap,
    get_interactivity,
    is_sequence,
    matplotlib_style_context,
    mylog,
)

from ._commons import get_canvas, validate_image_name

# Maps a matplotlib backend name to
# [submodule of matplotlib.backends, canvas class name, manager class name].
# The manager class name is None when no manager should be instantiated.
BACKEND_SPECS = {
    "GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"],
    "GTKAgg": ["backend_gtkagg", "FigureCanvasGTKAgg", None],
    "GTKCairo": ["backend_gtkcairo", "FigureCanvasGTKCairo", None],
    "MacOSX": ["backend_macosx", "FigureCanvasMac", "FigureManagerMac"],
    "Qt4Agg": ["backend_qt4agg", "FigureCanvasQTAgg", None],
    "Qt5Agg": ["backend_qt5agg", "FigureCanvasQTAgg", None],
    "TkAgg": ["backend_tkagg", "FigureCanvasTkAgg", None],
    "WX": ["backend_wx", "FigureCanvasWx", None],
    "WXAgg": ["backend_wxagg", "FigureCanvasWxAgg", None],
    "GTK3Cairo": [
        "backend_gtk3cairo",
        "FigureCanvasGTK3Cairo",
        "FigureManagerGTK3Cairo",
    ],
    "GTK3Agg": ["backend_gtk3agg", "FigureCanvasGTK3Agg", "FigureManagerGTK3Agg"],
    "WebAgg": ["backend_webagg", "FigureCanvasWebAgg", None],
    "nbAgg": ["backend_nbagg", "FigureCanvasNbAgg", "FigureManagerNbAgg"],
    "agg": ["backend_agg", "FigureCanvasAgg", None],
}


class CallbackWrapper:
    """Bundle the plot state that is copied from a viewer, a window plot and
    a fixed resolution buffer into a single object.

    Attributes that are only set conditionally:

    * ``image`` exists only if the plot axes already hold at least one image.
    * ``_period`` exists only if ``frb.axis < 3``.
    """

    def __init__(self, viewer, window_plot, frb, field, font_properties, font_color):
        self.frb = frb
        self.data = frb.data_source
        self._axes = window_plot.axes
        self._figure = window_plot.figure
        if len(self._axes.images) > 0:
            self.image = self._axes.images[0]
        if frb.axis < 3:
            # domain width along the two in-plane axes of the image
            DD = frb.ds.domain_width
            xax = frb.ds.coordinates.x_axis[frb.axis]
            yax = frb.ds.coordinates.y_axis[frb.axis]
            self._period = (DD[xax], DD[yax])
        self.ds = frb.ds
        self.xlim = viewer.xlim
        self.ylim = viewer.ylim
        self._axes_unit_names = viewer._axes_unit_names
        # off-axis slices are exposed under the "CuttingPlane" type name
        if "OffAxisSlice" in viewer._plot_type:
            self._type_name = "CuttingPlane"
        else:
            self._type_name = viewer._plot_type
        self.aspect = window_plot._aspect
        self.font_properties = font_properties
        self.font_color = font_color
        self.field = field


class PlotMPL:
    """A base class for all yt plots made using matplotlib, that is backend independent."""

    def __init__(self, fsize, axrect, figure, axes):
        """Initialize PlotMPL class

        Parameters
        ----------
        fsize : number or sequence
            Figure size. A scalar is expanded to ``(fsize, fsize)`` when a new
            figure is created.
        axrect : sequence
            Rectangle passed to ``add_axes`` / ``set_position`` for the main axes.
        figure : matplotlib figure or None
            Existing figure to reuse (it is resized to ``fsize``), or None to
            create a new one.
        axes : matplotlib axes or None
            Existing axes to reuse (they are cleared and repositioned), or
            None to create new ones on the figure.
        """
        import matplotlib.figure

        self._plot_valid = True
        if figure is None:
            if not is_sequence(fsize):
                fsize = (fsize, fsize)
            self.figure = matplotlib.figure.Figure(figsize=fsize, frameon=True)
        else:
            figure.set_size_inches(fsize)
            self.figure = figure
        if axes is None:
            self._create_axes(axrect)
        else:
            axes.cla()
            axes.set_position(axrect)
            self.axes = axes
        self.interactivity = get_interactivity()

        figure_canvas, figure_manager = self._get_canvas_classes()
        self.canvas = figure_canvas(self.figure)
        # self.manager is only defined for backends that declare a manager
        if figure_manager is not None:
            self.manager = figure_manager(self.canvas, 1)

        for which in ["major", "minor"]:
            for axis in "xy":
                self.axes.tick_params(
                    which=which, axis=axis, direction="in", top=True, right=True
                )

    def _create_axes(self, axrect):
        """Create the main axes on the figure at the rectangle ``axrect``."""
        self.axes = self.figure.add_axes(axrect)

    def _get_canvas_classes(self):
        """Look up the canvas and manager classes for the active backend.

        The current matplotlib backend is used when interactivity is enabled,
        "agg" otherwise. A backend that is not listed in ``BACKEND_SPECS``
        falls back to "agg" with a warning, so that the caller can always
        unpack the result.

        Returns
        -------
        FigureCanvas : class
            The canvas class of the selected backend.
        FigureManager : class or None
            The manager class of the selected backend, or None if the backend
            spec does not declare one.
        """
        if self.interactivity:
            key = str(matplotlib.get_backend())
        else:
            key = "agg"

        if key not in BACKEND_SPECS:
            # Previously this returned None, which made __init__ fail with an
            # opaque TypeError while unpacking the result.
            mylog.warning(
                "Unsupported matplotlib backend '%s'; falling back to 'agg'.", key
            )
            key = "agg"
        module, fig_canvas, fig_manager = BACKEND_SPECS[key]

        submod = importlib.import_module(f"matplotlib.backends.{module}")
        FigureCanvas = getattr(submod, fig_canvas)
        if fig_manager is None:
            return FigureCanvas, None
        return FigureCanvas, getattr(submod, fig_manager)

    def save(self, name, mpl_kwargs=None, canvas=None):
        """Choose backend and save image to disk

        Parameters
        ----------
        name : str
            Output file name; it is passed through ``validate_image_name``.
        mpl_kwargs : dict, optional
            Keyword arguments forwarded to ``canvas.print_figure``.
        canvas : optional
            Currently ignored: the value is overwritten before being used.
            Kept so that existing callers keep working.

        Returns
        -------
        name : str
            The validated file name the figure was written to.
        """
        if mpl_kwargs is None:
            mpl_kwargs = {}
        # "papertype" is only injected for matplotlib versions older than 3.3.0
        if "papertype" not in mpl_kwargs and parse_version(
            matplotlib.__version__
        ) < parse_version("3.3.0"):
            mpl_kwargs["papertype"] = "auto"

        name = validate_image_name(name)

        try:
            canvas = get_canvas(self.figure, name)
        except ValueError:
            # fall back to the canvas created at construction time
            canvas = self.canvas

        mylog.info("Saving plot %s", name)
        with matplotlib_style_context():
            canvas.print_figure(name, **mpl_kwargs)
        return name

    def show(self):
        """Show the plot through its manager, or through the canvas if no
        manager was created."""
        try:
            self.manager.show()
        except AttributeError:
            self.canvas.show()

    def _get_labels(self):
        """Return every text object of the main axes: major and minor tick
        labels, title, axis labels and offset texts."""
        ax = self.axes
        labels = ax.xaxis.get_ticklabels() + ax.yaxis.get_ticklabels()
        labels += ax.xaxis.get_minorticklabels()
        labels += ax.yaxis.get_minorticklabels()
        labels += [
            ax.title,
            ax.xaxis.label,
            ax.yaxis.label,
            ax.xaxis.get_offset_text(),
            ax.yaxis.get_offset_text(),
        ]
        return labels

    def _set_font_properties(self, font_properties, font_color):
        """Apply font properties, and optionally a color, to all labels.

        Parameters
        ----------
        font_properties
            Passed to ``set_fontproperties`` of every label.
        font_color
            Passed to ``set_color`` of every label; skipped when None.
        """
        for label in self._get_labels():
            label.set_fontproperties(font_properties)
            if font_color is not None:
                # use the argument: this class does not define self.font_color
                label.set_color(font_color)

    def _repr_png_(self):
        """Render the figure with an Agg canvas and return the PNG bytes."""
        from ._mpl_imports import FigureCanvasAgg

        canvas = FigureCanvasAgg(self.figure)
        f = BytesIO()
        with matplotlib_style_context():
            canvas.print_figure(f)
        f.seek(0)
        return f.read()


class ImagePlotMPL(PlotMPL):
    """A base class for yt plots made using imshow"""

    def __init__(self, fsize, axrect, caxrect, zlim, figure, axes, cax):
        """Initialize ImagePlotMPL class object

        Parameters
        ----------
        fsize, axrect, figure, axes
            Forwarded to ``PlotMPL.__init__``.
        caxrect : sequence
            Rectangle for the colorbar axes.
        zlim : pair
            ``(zmin, zmax)`` colorbar limits; either entry may be None.
        cax : matplotlib axes or None
            Existing colorbar axes to reuse (cleared and repositioned), or
            None to create new ones on the figure.
        """
        super().__init__(fsize, axrect, figure, axes)
        self.zmin, self.zmax = zlim
        if cax is None:
            self.cax = self.figure.add_axes(caxrect)
        else:
            cax.cla()
            cax.set_position(caxrect)
            self.cax = cax

    def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
        """Store output of imshow in image variable

        Parameters
        ----------
        data
            Image data; must provide ``to_ndarray()``.
        cbnorm : {"log10", "linear", "symlog"}
            Colorbar normalization.
        cblinthresh : number or None
            Linear threshold for "symlog". When None, the smallest non-zero
            absolute value of ``data`` is used.
        cmap
            Colormap; a tuple is resolved through ``get_brewer_cmap``.
        extent : sequence
            Image extent; every entry is converted to float.
        aspect
            Forwarded to ``imshow``.

        Raises
        ------
        ValueError
            If ``cbnorm`` is not one of the supported values.

        Notes
        -----
        ``self._transform`` is read here but is not set by this class.
        NOTE(review): presumably subclasses define it - confirm.
        """
        cbnorm_kwargs = dict(
            vmin=float(self.zmin) if self.zmin is not None else None,
            vmax=float(self.zmax) if self.zmax is not None else None,
        )
        if cbnorm == "log10":
            cbnorm_cls = matplotlib.colors.LogNorm
        elif cbnorm == "linear":
            cbnorm_cls = matplotlib.colors.Normalize
        elif cbnorm == "symlog":
            # if cblinthresh is not specified, try to come up with a reasonable default
            vmin = float(np.nanmin(data))
            vmax = float(np.nanmax(data))
            if cblinthresh is None:
                cblinthresh = np.nanmin(np.absolute(data)[data != 0])

            # note that for symlog the data range replaces zmin/zmax
            cbnorm_kwargs.update(dict(linthresh=cblinthresh, vmin=vmin, vmax=vmax))
            MPL_VERSION = parse_version(matplotlib.__version__)
            if MPL_VERSION >= parse_version("3.2.0"):
                # note that this creates an inconsistency between mpl versions
                # since the default value previous to mpl 3.4.0 is np.e
                # but it is only exposed since 3.2.0
                cbnorm_kwargs["base"] = 10

            cbnorm_cls = matplotlib.colors.SymLogNorm
        else:
            raise ValueError(f"Unknown value `cbnorm` == {cbnorm}")

        norm = cbnorm_cls(**cbnorm_kwargs)

        extent = [float(e) for e in extent]
        # tuple colormaps are from palettable (or brewer2mpl)
        if isinstance(cmap, tuple):
            cmap = get_brewer_cmap(cmap)

        if self._transform is None:
            # sets the transform to be an ax.TransData object, where the
            # coordinate system of the data is controlled by the xlim and ylim
            # of the data.
            transform = self.axes.transData
        else:
            transform = self._transform
        if hasattr(self.axes, "set_extent"):
            # CartoPy hangs if we do not set_extent before imshow if we are
            # displaying a small subset of the globe.  What I believe happens is
            # that the transform for the points on the outside results in
            # infinities, and then the scipy.spatial cKDTree hangs trying to
            # identify nearest points.
            #
            # Also, set_extent is defined by cartopy, so not all axes will have
            # it as a method.
            #
            # A potential downside is that other images may change, but I believe
            # the result of imshow is to set_extent *regardless*.  This just
            # changes the order in which it happens.
            #
            # NOTE: This is currently commented out because it breaks in some
            # instances.  It is left as a historical note because we will
            # eventually need some form of it.
            # self.axes.set_extent(extent)
            pass
        self.image = self.axes.imshow(
            data.to_ndarray(),
            origin="lower",
            extent=extent,
            norm=norm,
            aspect=aspect,
            cmap=cmap,
            interpolation="nearest",
            transform=transform,
        )
        if cbnorm == "symlog":
            formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh)
            self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
            # Build the tick list from powers of ten, depending on whether the
            # data is all non-negative, all non-positive, or of mixed sign.
            if np.nanmin(data) >= 0.0:
                yticks = [np.nanmin(data).v] + list(
                    10
                    ** np.arange(
                        np.rint(np.log10(cblinthresh)),
                        np.ceil(np.log10(np.nanmax(data))) + 1,
                    )
                )
            elif np.nanmax(data) <= 0.0:
                yticks = (
                    list(
                        -(
                            10
                            ** np.arange(
                                np.floor(np.log10(-np.nanmin(data))),
                                np.rint(np.log10(cblinthresh)) - 1,
                                -1,
                            )
                        )
                    )
                    + [np.nanmax(data).v]
                )
            else:
                yticks = (
                    list(
                        -(
                            10
                            ** np.arange(
                                np.floor(np.log10(-np.nanmin(data))),
                                np.rint(np.log10(cblinthresh)) - 1,
                                -1,
                            )
                        )
                    )
                    + [0]
                    + list(
                        10
                        ** np.arange(
                            np.rint(np.log10(cblinthresh)),
                            np.ceil(np.log10(np.nanmax(data))) + 1,
                        )
                    )
                )
            self.cb.set_ticks(yticks)
        else:
            self.cb = self.figure.colorbar(self.image, self.cax)
        for which in ["major", "minor"]:
            self.cax.tick_params(which=which, axis="y", direction="in")

    def _get_best_layout(self):
        """Compute the figure size and the axes/colorbar rectangles.

        Reads ``_figure_size``, ``_aspect``, ``_draw_colorbar``, ``_cb_size``,
        ``_ax_text_size``, ``_draw_axes``, ``_top_buff_size`` and, if present,
        ``_unit_aspect`` from the instance.

        Returns
        -------
        size : list
            Total ``[width, height]`` of the figure (callers in this class
            pass it to ``set_size_inches``).
        axrect : tuple
            ``(left, bottom, width, height)`` of the main axes, as fractions
            of the figure size.
        caxrect : tuple
            ``(left, bottom, width, height)`` of the colorbar axes, as
            fractions of the figure size.
        """
        # Ensure the figure size along the long axis is always equal to _figure_size
        if is_sequence(self._figure_size):
            x_fig_size = self._figure_size[0]
            y_fig_size = self._figure_size[1]
        else:
            x_fig_size = self._figure_size
            y_fig_size = self._figure_size / self._aspect

        if hasattr(self, "_unit_aspect"):
            y_fig_size = y_fig_size * self._unit_aspect

        if self._draw_colorbar:
            cb_size = self._cb_size
            cb_text_size = self._ax_text_size[1] + 0.45
        else:
            cb_size = x_fig_size * 0.04
            cb_text_size = 0.0

        if self._draw_axes:
            x_axis_size = self._ax_text_size[0]
            y_axis_size = self._ax_text_size[1]
        else:
            x_axis_size = x_fig_size * 0.04
            y_axis_size = y_fig_size * 0.04

        top_buff_size = self._top_buff_size

        # with neither axes nor colorbar, the image fills the whole figure
        if not self._draw_axes and not self._draw_colorbar:
            x_axis_size = 0.0
            y_axis_size = 0.0
            cb_size = 0.0
            cb_text_size = 0.0
            top_buff_size = 0.0

        xbins = np.array([x_axis_size, x_fig_size, cb_size, cb_text_size])
        ybins = np.array([y_axis_size, y_fig_size, top_buff_size])

        size = [xbins.sum(), ybins.sum()]

        x_frac_widths = xbins / size[0]
        y_frac_widths = ybins / size[1]

        # axrect is the rectangle defining the area of the
        # axis object of the plot.  Its range goes from 0 to 1 in
        # x and y directions.  The first two values are the x,y
        # start values of the axis object (lower left corner), and the
        # second two values are the size of the axis object.  To get
        # the upper right corner, add the first x,y to the second x,y.
        axrect = (
            x_frac_widths[0],
            y_frac_widths[0],
            x_frac_widths[1],
            y_frac_widths[1],
        )

        # caxrect is the rectangle defining the area of the colorbar
        # axis object of the plot.  It is defined just as the axrect
        # tuple is.
        caxrect = (
            x_frac_widths[0] + x_frac_widths[1],
            y_frac_widths[0],
            x_frac_widths[2],
            y_frac_widths[1],
        )

        return size, axrect, caxrect

    def _toggle_axes(self, choice, draw_frame=None):
        """
        Turn on/off displaying the axis ticks and labels for a plot.

        Parameters
        ----------
        choice : boolean
            If True, set the axes to be drawn. If False, set the axes to not be
            drawn.
        draw_frame : boolean, optional
            Whether the axes frame is drawn. Defaults to the value of
            ``choice``.
        """
        if draw_frame is None:
            draw_frame = choice
        self._draw_axes = choice
        self._draw_frame = draw_frame
        self.axes.set_frame_on(draw_frame)
        self.axes.get_xaxis().set_visible(choice)
        self.axes.get_yaxis().set_visible(choice)
        # the layout depends on _draw_axes, so recompute and re-apply it
        size, axrect, caxrect = self._get_best_layout()
        self.axes.set_position(axrect)
        self.cax.set_position(caxrect)
        self.figure.set_size_inches(*size)

    def _toggle_colorbar(self, choice):
        """
        Turn on/off displaying the colorbar for a plot

        choice = True or False
        """
        self._draw_colorbar = choice
        self.cax.set_visible(choice)
        # the layout depends on _draw_colorbar, so recompute and re-apply it
        size, axrect, caxrect = self._get_best_layout()
        self.axes.set_position(axrect)
        self.cax.set_position(caxrect)
        self.figure.set_size_inches(*size)

    def _get_labels(self):
        """Return the labels of the main axes plus the colorbar's tick labels,
        axis label and offset text."""
        labels = super()._get_labels()
        cbax = self.cb.ax
        labels += cbax.yaxis.get_ticklabels()
        labels += [cbax.yaxis.label, cbax.yaxis.get_offset_text()]
        return labels

    def hide_axes(self, draw_frame=None):
        """
        Hide the axes for a plot including ticks and labels
        """
        self._toggle_axes(False, draw_frame)
        return self

    def show_axes(self):
        """
        Show the axes for a plot including ticks and labels
        """
        self._toggle_axes(True)
        return self

    def hide_colorbar(self):
        """
        Hide the colorbar for a plot including ticks and labels
        """
        self._toggle_colorbar(False)
        return self

    def show_colorbar(self):
        """
        Show the colorbar for a plot including ticks and labels
        """
        self._toggle_colorbar(True)
        return self


def get_multi_plot(nx, ny, colorbar="vertical", bw=4, dpi=300, cbar_padding=0.4):
    r"""Construct a multiple axes plot object, with or without a colorbar, into
    which multiple plots may be inserted.

    This will create a set of :class:`matplotlib.axes.Axes`, all lined up into
    a grid, which are then returned to the user and which can be used to plot
    multiple plots on a single figure.

    Parameters
    ----------
    nx : int
        Number of axes to create along the x-direction
    ny : int
        Number of axes to create along the y-direction
    colorbar : {'vertical', 'horizontal', None}, optional
        Should Axes objects for colorbars be allocated, and if so, should they
        correspond to the horizontal or vertical set of axes? The string is
        matched case-insensitively; any other string allocates no colorbars.
    bw : number
        The base height/width of an axes object inside the figure, in inches
    dpi : number
        The dots per inch fed into the Figure instantiation
    cbar_padding : number
        Extra room reserved for the colorbars, in units of ``bw``. It is added
        to the figure width for vertical colorbars and to the figure height
        for horizontal ones.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        The figure created inside which the axes reside
    tr : list of list of :class:`matplotlib.axes.Axes` objects
        This is a list, where the inner list is along the x-axis and the outer
        is along the y-axis
    cbars : list of :class:`matplotlib.axes.Axes` objects
        Each of these is an axes onto which a colorbar can be placed.

    Notes
    -----
    This is a simple implementation for a common use case.  Viewing the source
    can be instructive, and is encouraged to see how to generate more
    complicated or more specific sets of multiplots for your own purposes.
    """
    import matplotlib.figure

    from ._mpl_imports import FigureCanvasAgg

    hf, wf = 1.0 / ny, 1.0 / nx
    orientation = None if colorbar is None else colorbar.lower()

    # fudge_x / fudge_y are the fractions of the figure width / height that
    # the grid of plots occupies; the remainder is left for the colorbars.
    fudge_x = fudge_y = 1.0
    if orientation == "vertical":
        fudge_x = nx / (cbar_padding + nx)
    elif orientation == "horizontal":
        fudge_y = ny / (cbar_padding + ny)

    fig = matplotlib.figure.Figure((bw * nx / fudge_x, bw * ny / fudge_y), dpi=dpi)
    fig.set_canvas(FigureCanvasAgg(fig))
    fig.subplots_adjust(
        wspace=0.0, hspace=0.0, top=1.0, bottom=0.0, left=0.0, right=1.0
    )

    tr = []
    for j in range(ny):
        row = []
        for i in range(nx):
            left = i * wf * fudge_x
            # rows are laid out from the top of the figure downwards
            bottom = fudge_y * (1.0 - (j + 1) * hf) + (1.0 - fudge_y)
            row.append(fig.add_axes([left, bottom, wf * fudge_x, hf * fudge_y]))
        tr.append(row)

    cbars = []
    if orientation == "horizontal":
        for i in range(nx):
            # left, bottom, width, height
            # Here we want 0.10 on each side of the colorbar
            # We want it to be 0.05 tall
            # And we want a buffer of 0.15
            ax = fig.add_axes(
                [
                    wf * (i + 0.10) * fudge_x,
                    hf * fudge_y * 0.20,
                    wf * (1 - 0.20) * fudge_x,
                    hf * fudge_y * 0.05,
                ]
            )
            cbars.append(ax)
    elif orientation == "vertical":
        for j in range(ny):
            ax = fig.add_axes(
                [
                    wf * (nx + 0.05) * fudge_x,
                    hf * fudge_y * (ny - (j + 0.95)),
                    wf * fudge_x * 0.05,
                    hf * fudge_y * 0.90,
                ]
            )
            ax.clear()
            cbars.append(ax)
    return fig, tr, cbars