1import re 2import warnings 3from functools import wraps 4from numbers import Number 5 6import matplotlib 7import numpy as np 8 9from yt.data_objects.data_containers import YTDataContainer 10from yt.data_objects.level_sets.clump_handling import Clump 11from yt.data_objects.selection_objects.cut_region import YTCutRegion 12from yt.data_objects.static_output import Dataset 13from yt.frontends.ytdata.data_structures import YTClumpContainer 14from yt.funcs import is_sequence, mylog, validate_width_tuple 15from yt.geometry.geometry_handler import is_curvilinear 16from yt.geometry.unstructured_mesh_handler import UnstructuredIndex 17from yt.units import dimensions 18from yt.units.yt_array import YTArray, YTQuantity, uhstack 19from yt.utilities.exceptions import YTDataTypeUnsupported 20from yt.utilities.lib.geometry_utils import triangle_plane_intersect 21from yt.utilities.lib.line_integral_convolution import line_integral_convolution_2d 22from yt.utilities.lib.mesh_triangulation import triangulate_indices 23from yt.utilities.lib.pixelization_routines import ( 24 pixelize_cartesian, 25 pixelize_off_axis_cartesian, 26) 27from yt.utilities.math_utils import periodic_ray 28from yt.utilities.on_demand_imports import NotAModule 29from yt.visualization.image_writer import apply_colormap 30 31callback_registry = {} 32 33 34def _verify_geometry(func): 35 @wraps(func) 36 def _check_geometry(self, plot): 37 geom = plot.data.ds.coordinates.name 38 supp = self._supported_geometries 39 cs = getattr(self, "coord_system", None) 40 if supp is None or geom in supp: 41 return func(self, plot) 42 if cs in ("axis", "figure") and "force" not in supp: 43 return func(self, plot) 44 raise YTDataTypeUnsupported(geom, supp) 45 46 return _check_geometry 47 48 49class PlotCallback: 50 # _supported_geometries is set by subclasses of PlotCallback to a tuple of 51 # strings corresponding to the names of the geometries that a callback 52 # supports. By default it is None, which means it supports everything. 53 # Note that if there's a coord_system parameter that is set to "axis" or 54 # "figure" this is disregarded. If "force" is included in the tuple, it 55 # will *not* check whether or not the coord_system is in axis or figure, 56 # and will only look at the geometries. 57 _supported_geometries = None 58 59 def __init_subclass__(cls, *args, **kwargs): 60 super().__init_subclass__(*args, **kwargs) 61 callback_registry[cls.__name__] = cls 62 cls.__call__ = _verify_geometry(cls.__call__) 63 64 def __init__(self, *args, **kwargs): 65 pass 66 67 def __call__(self, plot): 68 raise NotImplementedError 69 70 def _project_coords(self, plot, coord): 71 """ 72 Convert coordinates from simulation data coordinates to projected 73 data coordinates. Simulation data coordinates are three dimensional, 74 and can either be specified as a YTArray or as a list or array in 75 code_length units. Projected data units are 2D versions of the 76 simulation data units relative to the axes of the final plot. 77 """ 78 if len(coord) == 3: 79 if not isinstance(coord, YTArray): 80 coord = plot.data.ds.arr(coord, "code_length") 81 coord.convert_to_units("code_length") 82 ax = plot.data.axis 83 # if this is an on-axis projection or slice, then 84 # just grab the appropriate 2 coords for the on-axis view 85 if ax >= 0 and ax <= 2: 86 (xi, yi) = ( 87 plot.data.ds.coordinates.x_axis[ax], 88 plot.data.ds.coordinates.y_axis[ax], 89 ) 90 coord = (coord[xi], coord[yi]) 91 92 # if this is an off-axis project or slice (ie cutting plane) 93 # we have to calculate where the data coords fall in the projected 94 # plane 95 elif ax == 4: 96 # transpose is just to get [[x1,x2,...],[y1,y2,...],[z1,z2,...]] 97 # in the same order as plot.data.center for array arithmetic 98 coord_vectors = coord.transpose() - plot.data.center 99 x = np.dot(coord_vectors, plot.data.orienter.unit_vectors[1]) 100 y = np.dot(coord_vectors, plot.data.orienter.unit_vectors[0]) 101 # Transpose into image coords. Due to VR being not a 102 # right-handed coord system 103 coord = (y, x) 104 else: 105 raise ValueError("Object being plot must have a `data.axis` defined") 106 107 # if the position is already two-coords, it is expected to be 108 # in the proper projected orientation 109 else: 110 raise ValueError("'data' coordinates must be 3 dimensions") 111 return coord 112 113 def _convert_to_plot(self, plot, coord, offset=True): 114 """ 115 Convert coordinates from projected data coordinates to PlotWindow 116 plot coordinates. Projected data coordinates are two dimensional 117 and refer to the location relative to the specific axes being plotted, 118 although still in simulation units. PlotWindow plot coordinates 119 are locations as found in the final plot, usually with the origin 120 in the center of the image and the extent of the image defined by 121 the final plot axis markers. 122 """ 123 # coord should be a 2 x ncoord array-like datatype. 124 try: 125 ncoord = np.array(coord).shape[1] 126 except IndexError: 127 ncoord = 1 128 129 # Convert the data and plot limits to tiled numpy arrays so that 130 # convert_to_plot is automatically vectorized. 131 132 x0 = np.array(np.tile(plot.xlim[0].to("code_length"), ncoord)) 133 x1 = np.array(np.tile(plot.xlim[1].to("code_length"), ncoord)) 134 xx0 = np.tile(plot._axes.get_xlim()[0], ncoord) 135 xx1 = np.tile(plot._axes.get_xlim()[1], ncoord) 136 137 y0 = np.array(np.tile(plot.ylim[0].to("code_length"), ncoord)) 138 y1 = np.array(np.tile(plot.ylim[1].to("code_length"), ncoord)) 139 yy0 = np.tile(plot._axes.get_ylim()[0], ncoord) 140 yy1 = np.tile(plot._axes.get_ylim()[1], ncoord) 141 142 try: 143 ccoord = np.array(coord.to("code_length")) 144 except AttributeError: 145 ccoord = np.array(coord) 146 147 # We need a special case for when we are only given one coordinate. 148 if ccoord.shape == (2,): 149 return np.array( 150 [ 151 ((ccoord[0] - x0) / (x1 - x0) * (xx1 - xx0) + xx0)[0], 152 ((ccoord[1] - y0) / (y1 - y0) * (yy1 - yy0) + yy0)[0], 153 ] 154 ) 155 else: 156 return np.array( 157 [ 158 (ccoord[0][:] - x0) / (x1 - x0) * (xx1 - xx0) + xx0, 159 (ccoord[1][:] - y0) / (y1 - y0) * (yy1 - yy0) + yy0, 160 ] 161 ) 162 163 def _sanitize_coord_system(self, plot, coord, coord_system): 164 """ 165 Given a set of one or more x,y (and z) coordinates and a coordinate 166 system, convert the coordinates (and transformation) ready for final 167 plotting. 168 169 Parameters 170 ---------- 171 172 plot: a PlotMPL subclass 173 The plot that we are converting coordinates for 174 175 coord: array-like 176 Coordinates in some coordinate system: [x,y,z]. 177 Alternatively, can specify multiple coordinates as: 178 [[x1,x2,...,xn], [y1, y2,...,yn], [z1,z2,...,zn]] 179 180 coord_system: string 181 182 Possible values include: 183 184 * ``'data'`` 185 3D data coordinates relative to original dataset 186 187 * ``'plot'`` 188 2D coordinates as defined by the final axis locations 189 190 * ``'axis'`` 191 2D coordinates within the axis object from (0,0) in lower left 192 to (1,1) in upper right. Same as matplotlib axis coords. 193 194 * ``'figure'`` 195 2D coordinates within figure object from (0,0) in lower left 196 to (1,1) in upper right. Same as matplotlib figure coords. 197 """ 198 # Assure coords are either a YTArray or numpy array 199 coord = np.asanyarray(coord, dtype="float64") 200 # if in data coords, project them to plot coords 201 if coord_system == "data": 202 if len(coord) < 3: 203 raise ValueError( 204 "Coordinates in 'data' coordinate system need to be in 3D" 205 ) 206 coord = self._project_coords(plot, coord) 207 coord = self._convert_to_plot(plot, coord) 208 # if in plot coords, define the transform correctly 209 if coord_system == "data" or coord_system == "plot": 210 self.transform = plot._axes.transData 211 return coord 212 # if in axis coords, define the transform correctly 213 if coord_system == "axis": 214 self.transform = plot._axes.transAxes 215 if len(coord) > 2: 216 raise ValueError( 217 "Coordinates in 'axis' coordinate system need to be in 2D" 218 ) 219 return coord 220 # if in figure coords, define the transform correctly 221 elif coord_system == "figure": 222 self.transform = plot._figure.transFigure 223 return coord 224 else: 225 raise ValueError( 226 "Argument coord_system must have a value of " 227 "'data', 'plot', 'axis', or 'figure'." 228 ) 229 230 def _physical_bounds(self, plot): 231 xlims = tuple(v.in_units("code_length") for v in plot.xlim) 232 ylims = tuple(v.in_units("code_length") for v in plot.ylim) 233 return xlims + ylims 234 235 def _plot_bounds(self, plot): 236 return plot._axes.get_xlim() + plot._axes.get_ylim() 237 238 def _pixel_scale(self, plot): 239 x0, x1, y0, y1 = self._physical_bounds(plot) 240 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 241 dx = (xx1 - xx0) / (x1 - x0) 242 dy = (yy1 - yy0) / (y1 - y0) 243 return dx, dy 244 245 def _set_font_properties(self, plot, labels, **kwargs): 246 """ 247 This sets all of the text instances created by a callback to have 248 the same font size and properties as all of the other fonts in the 249 figure. If kwargs are set, they override the defaults. 250 """ 251 # This is a little messy because there is no trivial way to update 252 # a MPL.font_manager.FontProperties object with new attributes 253 # aside from setting them individually. So we pick out the relevant 254 # MPL.Text() kwargs from the local kwargs and let them override the 255 # defaults. 256 local_font_properties = plot.font_properties.copy() 257 258 # Turn off the default TT font file, otherwise none of this works. 259 local_font_properties.set_file(None) 260 local_font_properties.set_family("stixgeneral") 261 262 if "family" in kwargs: 263 local_font_properties.set_family(kwargs["family"]) 264 if "file" in kwargs: 265 local_font_properties.set_file(kwargs["file"]) 266 if "fontconfig_pattern" in kwargs: 267 local_font_properties.set_fontconfig_pattern(kwargs["fontconfig_pattern"]) 268 if "name" in kwargs: 269 local_font_properties.set_name(kwargs["name"]) 270 if "size" in kwargs: 271 local_font_properties.set_size(kwargs["size"]) 272 if "slant" in kwargs: 273 local_font_properties.set_slant(kwargs["slant"]) 274 if "stretch" in kwargs: 275 local_font_properties.set_stretch(kwargs["stretch"]) 276 if "style" in kwargs: 277 local_font_properties.set_style(kwargs["style"]) 278 if "variant" in kwargs: 279 local_font_properties.set_variant(kwargs["variant"]) 280 if "weight" in kwargs: 281 local_font_properties.set_weight(kwargs["weight"]) 282 283 # For each label, set the font properties and color to the figure 284 # defaults if not already set in the callback itself 285 for label in labels: 286 if plot.font_color is not None and "color" not in kwargs: 287 label.set_color(plot.font_color) 288 label.set_fontproperties(local_font_properties) 289 290 291class VelocityCallback(PlotCallback): 292 """ 293 Adds a 'quiver' plot of velocity to the plot, skipping all but 294 every *factor* datapoint. *scale* is the data units per arrow 295 length unit using *scale_units* and *plot_args* allows you to 296 pass in matplotlib arguments (see matplotlib.axes.Axes.quiver 297 for more info). if *normalize* is True, the velocity fields 298 will be scaled by their local (in-plane) length, allowing 299 morphological features to be more clearly seen for fields 300 with substantial variation in field strength. 301 """ 302 303 _type_name = "velocity" 304 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 305 306 def __init__( 307 self, factor=16, scale=None, scale_units=None, normalize=False, plot_args=None 308 ): 309 PlotCallback.__init__(self) 310 self.factor = factor 311 self.scale = scale 312 self.scale_units = scale_units 313 self.normalize = normalize 314 if plot_args is None: 315 plot_args = {} 316 self.plot_args = plot_args 317 318 def __call__(self, plot): 319 ftype = plot.data._current_fluid_type 320 # Instantiation of these is cheap 321 if plot._type_name == "CuttingPlane": 322 if is_curvilinear(plot.data.ds.geometry): 323 raise NotImplementedError( 324 "Velocity annotation for cutting \ 325 plane is not supported for %s geometry" 326 % plot.data.ds.geometry 327 ) 328 if plot._type_name == "CuttingPlane": 329 qcb = CuttingQuiverCallback( 330 (ftype, "cutting_plane_velocity_x"), 331 (ftype, "cutting_plane_velocity_y"), 332 self.factor, 333 scale=self.scale, 334 normalize=self.normalize, 335 scale_units=self.scale_units, 336 plot_args=self.plot_args, 337 ) 338 else: 339 xax = plot.data.ds.coordinates.x_axis[plot.data.axis] 340 yax = plot.data.ds.coordinates.y_axis[plot.data.axis] 341 axis_names = plot.data.ds.coordinates.axis_name 342 343 bv = plot.data.get_field_parameter("bulk_velocity") 344 if bv is not None: 345 bv_x = bv[xax] 346 bv_y = bv[yax] 347 else: 348 bv_x = bv_y = 0 349 350 if ( 351 plot.data.ds.geometry in ["polar", "cylindrical"] 352 and axis_names[plot.data.axis] == "z" 353 ): 354 # polar_z and cyl_z is aligned with carteian_z 355 # should convert r-theta plane to x-y plane 356 xv = (ftype, "velocity_cartesian_x") 357 yv = (ftype, "velocity_cartesian_y") 358 else: 359 # for other cases (even for cylindrical geometry), 360 # orthogonal planes are generically Cartesian 361 xv = (ftype, f"velocity_{axis_names[xax]}") 362 yv = (ftype, f"velocity_{axis_names[yax]}") 363 364 # determine the full fields including field type 365 xv = plot.data._determine_fields(xv)[0] 366 yv = plot.data._determine_fields(yv)[0] 367 368 qcb = QuiverCallback( 369 xv, 370 yv, 371 self.factor, 372 scale=self.scale, 373 scale_units=self.scale_units, 374 normalize=self.normalize, 375 bv_x=bv_x, 376 bv_y=bv_y, 377 plot_args=self.plot_args, 378 ) 379 return qcb(plot) 380 381 382class MagFieldCallback(PlotCallback): 383 """ 384 Adds a 'quiver' plot of magnetic field to the plot, skipping all but 385 every *factor* datapoint. *scale* is the data units per arrow 386 length unit using *scale_units* and *plot_args* allows you to pass 387 in matplotlib arguments (see matplotlib.axes.Axes.quiver for more info). 388 if *normalize* is True, the magnetic fields will be scaled by their 389 local (in-plane) length, allowing morphological features to be more 390 clearly seen for fields with substantial variation in field strength. 391 """ 392 393 _type_name = "magnetic_field" 394 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 395 396 def __init__( 397 self, factor=16, scale=None, scale_units=None, normalize=False, plot_args=None 398 ): 399 PlotCallback.__init__(self) 400 self.factor = factor 401 self.scale = scale 402 self.scale_units = scale_units 403 self.normalize = normalize 404 if plot_args is None: 405 plot_args = {} 406 self.plot_args = plot_args 407 408 def __call__(self, plot): 409 ftype = plot.data._current_fluid_type 410 # Instantiation of these is cheap 411 if plot._type_name == "CuttingPlane": 412 if is_curvilinear(plot.data.ds.geometry): 413 raise NotImplementedError( 414 "Magnetic field annotation for cutting \ 415 plane is not supported for %s geometry" 416 % plot.data.ds.geometry 417 ) 418 qcb = CuttingQuiverCallback( 419 (ftype, "cutting_plane_magnetic_field_x"), 420 (ftype, "cutting_plane_magnetic_field_y"), 421 self.factor, 422 scale=self.scale, 423 scale_units=self.scale_units, 424 normalize=self.normalize, 425 plot_args=self.plot_args, 426 ) 427 else: 428 xax = plot.data.ds.coordinates.x_axis[plot.data.axis] 429 yax = plot.data.ds.coordinates.y_axis[plot.data.axis] 430 axis_names = plot.data.ds.coordinates.axis_name 431 432 if ( 433 plot.data.ds.geometry in ["polar", "cylindrical"] 434 and axis_names[plot.data.axis] == "z" 435 ): 436 # polar_z and cyl_z is aligned with carteian_z 437 # should convert r-theta plane to x-y plane 438 xv = (ftype, "magnetic_field_cartesian_x") 439 yv = (ftype, "magnetic_field_cartesian_y") 440 else: 441 # for other cases (even for cylindrical geometry), 442 # orthogonal planes are generically Cartesian 443 xv = (ftype, f"magnetic_field_{axis_names[xax]}") 444 yv = (ftype, f"magnetic_field_{axis_names[yax]}") 445 446 qcb = QuiverCallback( 447 xv, 448 yv, 449 self.factor, 450 scale=self.scale, 451 scale_units=self.scale_units, 452 normalize=self.normalize, 453 plot_args=self.plot_args, 454 ) 455 return qcb(plot) 456 457 458class QuiverCallback(PlotCallback): 459 """ 460 Adds a 'quiver' plot to any plot, using the *field_x* and *field_y* 461 from the associated data, skipping every *factor* datapoints. 462 *scale* is the data units per arrow length unit using *scale_units* 463 and *plot_args* allows you to pass in matplotlib arguments (see 464 matplotlib.axes.Axes.quiver for more info). if *normalize* is True, 465 the fields will be scaled by their local (in-plane) length, allowing 466 morphological features to be more clearly seen for fields with 467 substantial variation in field strength. 468 """ 469 470 _type_name = "quiver" 471 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 472 473 def __init__( 474 self, 475 field_x, 476 field_y, 477 factor=16, 478 scale=None, 479 scale_units=None, 480 normalize=False, 481 bv_x=0, 482 bv_y=0, 483 plot_args=None, 484 ): 485 PlotCallback.__init__(self) 486 self.field_x = field_x 487 self.field_y = field_y 488 self.bv_x = bv_x 489 self.bv_y = bv_y 490 self.factor = factor 491 self.scale = scale 492 self.scale_units = scale_units 493 self.normalize = normalize 494 if plot_args is None: 495 plot_args = {} 496 self.plot_args = plot_args 497 498 def __call__(self, plot): 499 x0, x1, y0, y1 = self._physical_bounds(plot) 500 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 501 bounds = [x0, x1, y0, y1] 502 periodic = int(any(plot.data.ds.periodicity)) 503 504 def transform(field_name, vector_value): 505 field_units = plot.data[field_name].units 506 507 def _transformed_field(field, data): 508 return data[field_name] - data.ds.arr(vector_value, field_units) 509 510 plot.data.ds.add_field( 511 ("gas", f"transformed_{field_name}"), 512 sampling_type="cell", 513 function=_transformed_field, 514 units=field_units, 515 display_field=False, 516 ) 517 518 if self.bv_x != 0.0 or self.bv_x != 0.0: 519 # We create a relative vector field 520 transform(self.field_x, self.bv_x) 521 transform(self.field_y, self.bv_y) 522 field_x = f"transformed_{self.field_x}" 523 field_y = f"transformed_{self.field_y}" 524 else: 525 field_x, field_y = self.field_x, self.field_y 526 527 # We are feeding this size into the pixelizer, where it will properly 528 # set it in reverse order 529 nx = plot.image._A.shape[1] // self.factor 530 ny = plot.image._A.shape[0] // self.factor 531 pixX = plot.data.ds.coordinates.pixelize( 532 plot.data.axis, 533 plot.data, 534 field_x, 535 bounds, 536 (nx, ny), 537 False, # antialias 538 periodic, 539 ) 540 pixY = plot.data.ds.coordinates.pixelize( 541 plot.data.axis, 542 plot.data, 543 field_y, 544 bounds, 545 (nx, ny), 546 False, # antialias 547 periodic, 548 ) 549 X, Y = np.meshgrid( 550 np.linspace(xx0, xx1, nx, endpoint=True), 551 np.linspace(yy0, yy1, ny, endpoint=True), 552 ) 553 if self.normalize: 554 nn = np.sqrt(pixX ** 2 + pixY ** 2) 555 pixX /= nn 556 pixY /= nn 557 plot._axes.quiver( 558 X, 559 Y, 560 pixX, 561 pixY, 562 scale=self.scale, 563 scale_units=self.scale_units, 564 **self.plot_args, 565 ) 566 plot._axes.set_xlim(xx0, xx1) 567 plot._axes.set_ylim(yy0, yy1) 568 569 570class ContourCallback(PlotCallback): 571 """ 572 Add contours in *field* to the plot. *ncont* governs the number of 573 contours generated, *factor* governs the number of points used in the 574 interpolation, *take_log* governs how it is contoured and *clim* gives 575 the (upper, lower) limits for contouring. An alternate data source can be 576 specified with *data_source*, but by default the plot's data source will be 577 queried. 578 """ 579 580 _type_name = "contour" 581 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 582 583 def __init__( 584 self, 585 field, 586 ncont=5, 587 factor=4, 588 clim=None, 589 plot_args=None, 590 label=False, 591 take_log=None, 592 label_args=None, 593 text_args=None, 594 data_source=None, 595 ): 596 PlotCallback.__init__(self) 597 def_plot_args = {"colors": "k", "linestyles": "solid"} 598 def_text_args = {"colors": "w"} 599 self.ncont = ncont 600 self.field = field 601 self.factor = factor 602 self.clim = clim 603 self.take_log = take_log 604 if plot_args is None: 605 plot_args = def_plot_args 606 self.plot_args = plot_args 607 self.label = label 608 if label_args is not None: 609 text_args = label_args 610 warnings.warn( 611 "The label_args keyword is deprecated. Please use " 612 "the text_args keyword instead." 613 ) 614 if text_args is None: 615 text_args = def_text_args 616 self.text_args = text_args 617 self.data_source = data_source 618 619 def __call__(self, plot): 620 from matplotlib.tri import LinearTriInterpolator, Triangulation 621 622 # These need to be in code_length 623 x0, x1, y0, y1 = self._physical_bounds(plot) 624 625 # These are in plot coordinates, which may not be code coordinates. 626 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 627 628 # See the note about rows/columns in the pixelizer for more information 629 # on why we choose the bounds we do 630 numPoints_x = plot.image._A.shape[1] 631 numPoints_y = plot.image._A.shape[0] 632 633 # Multiply by dx and dy to go from data->plot 634 dx = (xx1 - xx0) / (x1 - x0) 635 dy = (yy1 - yy0) / (y1 - y0) 636 637 # We want xi, yi in plot coordinates 638 xi, yi = np.mgrid[ 639 xx0 : xx1 : numPoints_x / (self.factor * 1j), 640 yy0 : yy1 : numPoints_y / (self.factor * 1j), 641 ] 642 data = self.data_source or plot.data 643 644 if plot._type_name in ["CuttingPlane", "Projection", "Slice"]: 645 if plot._type_name == "CuttingPlane": 646 x = data["px"] * dx 647 y = data["py"] * dy 648 z = data[self.field] 649 elif plot._type_name in ["Projection", "Slice"]: 650 # Makes a copy of the position fields "px" and "py" and adds the 651 # appropriate shift to the copied field. 652 653 AllX = np.zeros(data["px"].size, dtype="bool") 654 AllY = np.zeros(data["py"].size, dtype="bool") 655 XShifted = data["px"].copy() 656 YShifted = data["py"].copy() 657 dom_x, dom_y = plot._period 658 for shift in np.mgrid[-1:1:3j]: 659 xlim = (data["px"] + shift * dom_x >= x0) & ( 660 data["px"] + shift * dom_x <= x1 661 ) 662 ylim = (data["py"] + shift * dom_y >= y0) & ( 663 data["py"] + shift * dom_y <= y1 664 ) 665 XShifted[xlim] += shift * dom_x 666 YShifted[ylim] += shift * dom_y 667 AllX |= xlim 668 AllY |= ylim 669 670 # At this point XShifted and YShifted are the shifted arrays of 671 # position data in data coordinates 672 wI = AllX & AllY 673 674 # This converts XShifted and YShifted into plot coordinates 675 x = ((XShifted[wI] - x0) * dx).ndarray_view() + xx0 676 y = ((YShifted[wI] - y0) * dy).ndarray_view() + yy0 677 z = data[self.field][wI] 678 679 # Both the input and output from the triangulator are in plot 680 # coordinates 681 triangulation = Triangulation(x, y) 682 zi = LinearTriInterpolator(triangulation, z)(xi, yi) 683 elif plot._type_name == "OffAxisProjection": 684 zi = plot.frb[self.field][:: self.factor, :: self.factor].transpose() 685 686 if self.take_log is None: 687 field = data._determine_fields([self.field])[0] 688 self.take_log = plot.ds._get_field_info(*field).take_log 689 690 if self.take_log: 691 zi = np.log10(zi) 692 693 if self.take_log and self.clim is not None: 694 self.clim = (np.log10(self.clim[0]), np.log10(self.clim[1])) 695 696 if self.clim is not None: 697 self.ncont = np.linspace(self.clim[0], self.clim[1], self.ncont) 698 699 cset = plot._axes.contour(xi, yi, zi, self.ncont, **self.plot_args) 700 plot._axes.set_xlim(xx0, xx1) 701 plot._axes.set_ylim(yy0, yy1) 702 703 if self.label: 704 plot._axes.clabel(cset, **self.text_args) 705 706 707class GridBoundaryCallback(PlotCallback): 708 """ 709 Draws grids on an existing PlotWindow object. Adds grid boundaries to a 710 plot, optionally with alpha-blending. By default, colors different levels of 711 grids with different colors going from white to black, but you can change to 712 any arbitrary colormap with cmap keyword, to all black grid edges for all 713 levels with cmap=None and edgecolors=None, or to an arbitrary single color 714 for grid edges with edgecolors='YourChosenColor' defined in any of the 715 standard ways (e.g., edgecolors='white', edgecolors='r', 716 edgecolors='#00FFFF', or edgecolor='0.3', where the last is a float in 0-1 717 scale indicating gray). Note that setting edgecolors overrides cmap if you 718 have both set to non-None values. Cutoff for display is at min_pix 719 wide. draw_ids puts the grid id a the corner of the grid (but its not so 720 great in projections...). id_loc determines which corner holds the grid id. 721 One can set min and maximum level of grids to display, and 722 can change the linewidth of the displayed grids. 723 """ 724 725 _type_name = "grids" 726 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 727 728 def __init__( 729 self, 730 alpha=0.7, 731 min_pix=1, 732 min_pix_ids=20, 733 draw_ids=False, 734 id_loc=None, 735 periodic=True, 736 min_level=None, 737 max_level=None, 738 cmap="B-W LINEAR_r", 739 edgecolors=None, 740 linewidth=1.0, 741 ): 742 PlotCallback.__init__(self) 743 self.alpha = alpha 744 self.min_pix = min_pix 745 self.min_pix_ids = min_pix_ids 746 self.draw_ids = draw_ids # put grid numbers in the corner. 747 if id_loc is None: 748 self.id_loc = "lower left" 749 else: 750 self.id_loc = id_loc.lower() # Make case-insensitive 751 if not self.draw_ids: 752 mylog.warning( 753 "Supplied id_loc but draw_ids is False. Not drawing grid ids" 754 ) 755 self.periodic = periodic 756 self.min_level = min_level 757 self.max_level = max_level 758 self.linewidth = linewidth 759 self.cmap = cmap 760 self.edgecolors = edgecolors 761 762 def __call__(self, plot): 763 if plot.data.ds.geometry == "cylindrical" and plot.data.ds.dimensionality == 3: 764 raise NotImplementedError( 765 "Grid annotation is only supported for \ 766 for 2D cylindrical geometry, not 3D" 767 ) 768 from matplotlib.colors import colorConverter 769 770 x0, x1, y0, y1 = self._physical_bounds(plot) 771 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 772 (dx, dy) = self._pixel_scale(plot) 773 (ypix, xpix) = plot.image._A.shape 774 ax = plot.data.axis 775 px_index = plot.data.ds.coordinates.x_axis[ax] 776 py_index = plot.data.ds.coordinates.y_axis[ax] 777 DW = plot.data.ds.domain_width 778 if self.periodic: 779 pxs, pys = np.mgrid[-1:1:3j, -1:1:3j] 780 else: 781 pxs, pys = np.mgrid[0:0:1j, 0:0:1j] 782 GLE, GRE, levels, block_ids = [], [], [], [] 783 for block, _mask in plot.data.blocks: 784 GLE.append(block.LeftEdge.in_units("code_length")) 785 GRE.append(block.RightEdge.in_units("code_length")) 786 levels.append(block.Level) 787 block_ids.append(block.id) 788 if len(GLE) == 0: 789 return 790 # Retain both units and registry 791 GLE = plot.ds.arr(GLE, units=GLE[0].units) 792 GRE = plot.ds.arr(GRE, units=GRE[0].units) 793 levels = np.array(levels) 794 min_level = self.min_level or 0 795 max_level = self.max_level or levels.max() 796 797 # sort the four arrays in order of ascending level, this makes images look nicer 798 new_indices = np.argsort(levels) 799 levels = levels[new_indices] 800 GLE = GLE[new_indices] 801 GRE = GRE[new_indices] 802 block_ids = np.array(block_ids)[new_indices] 803 804 for px_off, py_off in zip(pxs.ravel(), pys.ravel()): 805 pxo = px_off * DW[px_index] 806 pyo = py_off * DW[py_index] 807 left_edge_x = np.array((GLE[:, px_index] + pxo - x0) * dx) + xx0 808 left_edge_y = np.array((GLE[:, py_index] + pyo - y0) * dy) + yy0 809 right_edge_x = np.array((GRE[:, px_index] + pxo - x0) * dx) + xx0 810 right_edge_y = np.array((GRE[:, py_index] + pyo - y0) * dy) + yy0 811 xwidth = xpix * (right_edge_x - left_edge_x) / (xx1 - xx0) 812 ywidth = ypix * (right_edge_y - left_edge_y) / (yy1 - yy0) 813 visible = np.logical_and( 814 np.logical_and(xwidth > self.min_pix, ywidth > self.min_pix), 815 np.logical_and(levels >= min_level, levels <= max_level), 816 ) 817 818 # Grids can either be set by edgecolors OR a colormap. 819 if self.edgecolors is not None: 820 edgecolors = colorConverter.to_rgba(self.edgecolors, alpha=self.alpha) 821 else: # use colormap if not explicitly overridden by edgecolors 822 if self.cmap is not None: 823 color_bounds = [0, plot.data.ds.index.max_level] 824 edgecolors = ( 825 apply_colormap( 826 levels[visible] * 1.0, 827 color_bounds=color_bounds, 828 cmap_name=self.cmap, 829 )[0, :, :] 830 * 1.0 831 / 255.0 832 ) 833 edgecolors[:, 3] = self.alpha 834 else: 835 edgecolors = (0.0, 0.0, 0.0, self.alpha) 836 837 if visible.nonzero()[0].size == 0: 838 continue 839 verts = np.array( 840 [ 841 (left_edge_x, left_edge_x, right_edge_x, right_edge_x), 842 (left_edge_y, right_edge_y, right_edge_y, left_edge_y), 843 ] 844 ) 845 verts = verts.transpose()[visible, :, :] 846 grid_collection = matplotlib.collections.PolyCollection( 847 verts, 848 facecolors="none", 849 edgecolors=edgecolors, 850 linewidth=self.linewidth, 851 ) 852 plot._axes.add_collection(grid_collection) 853 854 visible_ids = np.logical_and( 855 np.logical_and(xwidth > self.min_pix_ids, ywidth > self.min_pix_ids), 856 np.logical_and(levels >= min_level, levels <= max_level), 857 ) 858 859 if self.draw_ids: 860 plot_ids = np.where(visible_ids)[0] 861 x = np.empty(plot_ids.size) 862 y = np.empty(plot_ids.size) 863 for i, n in enumerate(plot_ids): 864 if self.id_loc == "lower left": 865 x[i] = left_edge_x[n] + (2 * (xx1 - xx0) / xpix) 866 y[i] = left_edge_y[n] + (2 * (yy1 - yy0) / ypix) 867 elif self.id_loc == "lower right": 868 x[i] = right_edge_x[n] - ( 869 (10 * len(str(block_ids[i])) - 2) * (xx1 - xx0) / xpix 870 ) 871 y[i] = left_edge_y[n] + (2 * (yy1 - yy0) / ypix) 872 elif self.id_loc == "upper left": 873 x[i] = left_edge_x[n] + (2 * (xx1 - xx0) / xpix) 874 y[i] = right_edge_y[n] - (12 * (yy1 - yy0) / ypix) 875 elif self.id_loc == "upper right": 876 x[i] = right_edge_x[n] - ( 877 (10 * len(str(block_ids[i])) - 2) * (xx1 - xx0) / xpix 878 ) 879 y[i] = right_edge_y[n] - (12 * (yy1 - yy0) / ypix) 880 else: 881 raise RuntimeError( 882 "Unrecognized id_loc value ('%s'). " 883 "Allowed values are 'lower left', lower right', " 884 "'upper left', and 'upper right'." % self.id_loc 885 ) 886 plot._axes.text(x[i], y[i], "%d" % block_ids[n], clip_on=True) 887 888 889class StreamlineCallback(PlotCallback): 890 """ 891 Add streamlines to any plot, using the *field_x* and *field_y* 892 from the associated data, skipping every *factor* datapoints like 893 'quiver'. *density* is the index of the amount of the streamlines. 894 *field_color* is a field to be used to colormap the streamlines. 895 If *display_threshold* is supplied, any streamline segments where 896 *field_color* is less than the threshold will be removed by having 897 their line width set to 0. 898 """ 899 900 _type_name = "streamlines" 901 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 902 903 def __init__( 904 self, 905 field_x, 906 field_y, 907 factor=16, 908 density=1, 909 field_color=None, 910 display_threshold=None, 911 plot_args=None, 912 ): 913 PlotCallback.__init__(self) 914 def_plot_args = {} 915 self.field_x = field_x 916 self.field_y = field_y 917 self.field_color = field_color 918 self.factor = factor 919 self.dens = density 920 self.display_threshold = display_threshold 921 if plot_args is None: 922 plot_args = def_plot_args 923 self.plot_args = plot_args 924 925 def __call__(self, plot): 926 bounds = self._physical_bounds(plot) 927 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 928 929 # We are feeding this size into the pixelizer, where it will properly 930 # set it in reverse order 931 nx = plot.image._A.shape[1] // self.factor 932 ny = plot.image._A.shape[0] // self.factor 933 pixX = plot.data.ds.coordinates.pixelize( 934 plot.data.axis, plot.data, self.field_x, bounds, (nx, ny) 935 ) 936 pixY = plot.data.ds.coordinates.pixelize( 937 plot.data.axis, plot.data, self.field_y, bounds, (nx, ny) 938 ) 939 if self.field_color: 940 field_colors = plot.data.ds.coordinates.pixelize( 941 plot.data.axis, plot.data, self.field_color, bounds, (nx, ny) 942 ) 943 944 if self.display_threshold: 945 946 mask = field_colors > self.display_threshold 947 lwdefault = matplotlib.rcParams["lines.linewidth"] 948 949 if "linewidth" in self.plot_args: 950 linewidth = self.plot_args["linewidth"] 951 else: 952 linewidth = lwdefault 953 954 try: 955 linewidth *= mask 956 self.plot_args["linewidth"] = linewidth 957 except ValueError as e: 958 err_msg = ( 959 "Error applying display threshold: linewidth" 960 + "must have shape ({}, {}) or be scalar" 961 ) 962 err_msg = err_msg.format(nx, ny) 963 raise ValueError(err_msg) from e 964 965 else: 966 field_colors = None 967 968 X, Y = ( 969 np.linspace(xx0, xx1, nx, endpoint=True), 970 np.linspace(yy0, yy1, ny, endpoint=True), 971 ) 972 streamplot_args = { 973 "x": X, 974 "y": Y, 975 "u": pixX, 976 "v": pixY, 977 "density": self.dens, 978 "color": field_colors, 979 } 980 streamplot_args.update(self.plot_args) 981 plot._axes.streamplot(**streamplot_args) 982 plot._axes.set_xlim(xx0, xx1) 983 plot._axes.set_ylim(yy0, yy1) 984 985 986class LinePlotCallback(PlotCallback): 987 """ 988 Overplot a line with endpoints at p1 and p2. p1 and p2 989 should be 2D or 3D coordinates consistent with the coordinate 990 system denoted in the "coord_system" keyword. 991 992 Parameters 993 ---------- 994 p1, p2 : 2- or 3-element tuples, lists, or arrays 995 These are the coordinates of the endpoints of the line. 996 997 coord_system : string, optional 998 This string defines the coordinate system of the coordinates p1 and p2. 999 Valid coordinates are: 1000 1001 "data" -- the 3D dataset coordinates 1002 1003 "plot" -- the 2D coordinates defined by the actual plot limits 1004 1005 "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is 1006 upper right 1007 1008 "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1) 1009 is upper right 1010 1011 plot_args : dictionary, optional 1012 This dictionary is passed to the MPL plot function for generating 1013 the line. By default, it is: {'color':'white', 'linewidth':2} 1014 1015 Examples 1016 -------- 1017 1018 >>> # Overplot a diagonal white line from the lower left corner to upper 1019 >>> # right corner 1020 >>> import yt 1021 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1022 >>> s = yt.SlicePlot(ds, "z", "density") 1023 >>> s.annotate_line([0, 0], [1, 1], coord_system="axis") 1024 >>> s.save() 1025 1026 >>> # Overplot a red dashed line from data coordinate (0.1, 0.2, 0.3) to 1027 >>> # (0.5, 0.6, 0.7) 1028 >>> import yt 1029 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1030 >>> s = yt.SlicePlot(ds, "z", "density") 1031 >>> s.annotate_line( 1032 ... [0.1, 0.2, 0.3], 1033 ... [0.5, 0.6, 0.7], 1034 ... coord_system="data", 1035 ... plot_args={"color": "red", "lineStyles": "--"}, 1036 ... ) 1037 >>> s.save() 1038 1039 """ 1040 1041 _type_name = "line" 1042 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 1043 1044 def __init__(self, p1, p2, data_coords=False, coord_system="data", plot_args=None): 1045 PlotCallback.__init__(self) 1046 def_plot_args = {"color": "white", "linewidth": 2} 1047 self.p1 = p1 1048 self.p2 = p2 1049 if plot_args is None: 1050 plot_args = def_plot_args 1051 self.plot_args = plot_args 1052 if data_coords: 1053 coord_system = "data" 1054 warnings.warn( 1055 "The data_coords keyword is deprecated. Please set " 1056 "the keyword coord_system='data' instead." 1057 ) 1058 self.coord_system = coord_system 1059 self.transform = None 1060 1061 def __call__(self, plot): 1062 p1 = self._sanitize_coord_system(plot, self.p1, coord_system=self.coord_system) 1063 p2 = self._sanitize_coord_system(plot, self.p2, coord_system=self.coord_system) 1064 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1065 plot._axes.plot( 1066 [p1[0], p2[0]], [p1[1], p2[1]], transform=self.transform, **self.plot_args 1067 ) 1068 plot._axes.set_xlim(xx0, xx1) 1069 plot._axes.set_ylim(yy0, yy1) 1070 1071 1072class ImageLineCallback(LinePlotCallback): 1073 """ 1074 This callback is deprecated, as it is simply a wrapper around 1075 the LinePlotCallback (ie annotate_image()). The only difference is 1076 that it uses coord_system="axis" by default. Please see LinePlotCallback 1077 for more information. 1078 1079 """ 1080 1081 _type_name = "image_line" 1082 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 1083 1084 def __init__(self, p1, p2, data_coords=False, coord_system="axis", plot_args=None): 1085 super().__init__(p1, p2, data_coords, coord_system, plot_args) 1086 warnings.warn( 1087 "The ImageLineCallback (annotate_image_line()) is " 1088 "deprecated. Please use the LinePlotCallback " 1089 "(annotate_line()) instead." 1090 ) 1091 1092 def __call__(self, plot): 1093 super().__call__(plot) 1094 1095 1096class CuttingQuiverCallback(PlotCallback): 1097 """ 1098 Get a quiver plot on top of a cutting plane, using *field_x* and 1099 *field_y*, skipping every *factor* datapoint in the discretization. 1100 *scale* is the data units per arrow length unit using *scale_units* 1101 and *plot_args* allows you to pass in matplotlib arguments (see 1102 matplotlib.axes.Axes.quiver for more info). if *normalize* is True, 1103 the fields will be scaled by their local (in-plane) length, allowing 1104 morphological features to be more clearly seen for fields with 1105 substantial variation in field strength. 1106 """ 1107 1108 _type_name = "cquiver" 1109 _supported_geometries = ("cartesian", "spectral_cube") 1110 1111 def __init__( 1112 self, 1113 field_x, 1114 field_y, 1115 factor=16, 1116 scale=None, 1117 scale_units=None, 1118 normalize=False, 1119 plot_args=None, 1120 ): 1121 PlotCallback.__init__(self) 1122 self.field_x = field_x 1123 self.field_y = field_y 1124 self.factor = factor 1125 self.scale = scale 1126 self.scale_units = scale_units 1127 self.normalize = normalize 1128 if plot_args is None: 1129 plot_args = {} 1130 self.plot_args = plot_args 1131 1132 def __call__(self, plot): 1133 x0, x1, y0, y1 = self._physical_bounds(plot) 1134 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1135 nx = plot.image._A.shape[1] // self.factor 1136 ny = plot.image._A.shape[0] // self.factor 1137 indices = np.argsort(plot.data["index", "dx"])[::-1].astype(np.int_) 1138 1139 pixX = np.zeros((ny, nx), dtype="f8") 1140 pixY = np.zeros((ny, nx), dtype="f8") 1141 pixelize_off_axis_cartesian( 1142 pixX, 1143 plot.data[("index", "x")].to("code_length"), 1144 plot.data[("index", "y")].to("code_length"), 1145 plot.data[("index", "z")].to("code_length"), 1146 plot.data["px"], 1147 plot.data["py"], 1148 plot.data["pdx"], 1149 plot.data["pdy"], 1150 plot.data["pdz"], 1151 plot.data.center, 1152 plot.data._inv_mat, 1153 indices, 1154 plot.data[self.field_x], 1155 (x0, x1, y0, y1), 1156 ) 1157 pixelize_off_axis_cartesian( 1158 pixY, 1159 plot.data[("index", "x")].to("code_length"), 1160 plot.data[("index", "y")].to("code_length"), 1161 plot.data[("index", "z")].to("code_length"), 1162 plot.data["px"], 1163 plot.data["py"], 1164 plot.data["pdx"], 1165 plot.data["pdy"], 1166 plot.data["pdz"], 1167 plot.data.center, 1168 plot.data._inv_mat, 1169 indices, 1170 plot.data[self.field_y], 1171 (x0, x1, y0, y1), 1172 ) 1173 X, Y = np.meshgrid( 1174 np.linspace(xx0, xx1, nx, endpoint=True), 1175 np.linspace(yy0, yy1, ny, endpoint=True), 1176 ) 1177 1178 if self.normalize: 1179 nn = np.sqrt(pixX ** 2 + pixY ** 2) 1180 pixX /= nn 1181 pixY /= nn 1182 1183 plot._axes.quiver( 1184 X, 1185 Y, 1186 pixX, 1187 pixY, 1188 scale=self.scale, 1189 scale_units=self.scale_units, 1190 **self.plot_args, 1191 ) 1192 plot._axes.set_xlim(xx0, xx1) 1193 plot._axes.set_ylim(yy0, yy1) 1194 1195 1196class ClumpContourCallback(PlotCallback): 1197 """ 1198 Take a list of *clumps* and plot them as a set of contours. 1199 """ 1200 1201 _type_name = "clumps" 1202 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 1203 1204 def __init__(self, clumps, plot_args=None): 1205 self.clumps = clumps 1206 if plot_args is None: 1207 plot_args = {} 1208 if "color" in plot_args: 1209 plot_args["colors"] = plot_args.pop("color") 1210 self.plot_args = plot_args 1211 1212 def __call__(self, plot): 1213 bounds = self._physical_bounds(plot) 1214 extent = self._plot_bounds(plot) 1215 1216 ax = plot.data.axis 1217 px_index = plot.data.ds.coordinates.x_axis[ax] 1218 py_index = plot.data.ds.coordinates.y_axis[ax] 1219 1220 xf = plot.data.ds.coordinates.axis_name[px_index] 1221 yf = plot.data.ds.coordinates.axis_name[py_index] 1222 dxf = f"d{xf}" 1223 dyf = f"d{yf}" 1224 1225 ny, nx = plot.image._A.shape 1226 buff = np.zeros((nx, ny), dtype="float64") 1227 for i, clump in enumerate(reversed(self.clumps)): 1228 mylog.info("Pixelizing contour %s", i) 1229 1230 if isinstance(clump, Clump): 1231 ftype = "index" 1232 elif isinstance(clump, YTClumpContainer): 1233 ftype = "grid" 1234 else: 1235 raise RuntimeError( 1236 f"Unknown field type for object of type {type(clump)}." 1237 ) 1238 1239 xf_copy = clump[ftype, xf].copy().in_units("code_length") 1240 yf_copy = clump[ftype, yf].copy().in_units("code_length") 1241 1242 temp = np.zeros((ny, nx), dtype="f8") 1243 pixelize_cartesian( 1244 temp, 1245 xf_copy, 1246 yf_copy, 1247 clump[ftype, dxf].in_units("code_length") / 2.0, 1248 clump[ftype, dyf].in_units("code_length") / 2.0, 1249 clump[ftype, dxf].d * 0.0 + i + 1, # inits inside Pixelize 1250 bounds, 1251 0, 1252 ) 1253 buff = np.maximum(temp, buff) 1254 self.rv = plot._axes.contour( 1255 buff, np.unique(buff), extent=extent, **self.plot_args 1256 ) 1257 1258 1259class ArrowCallback(PlotCallback): 1260 """ 1261 Overplot arrow(s) pointing at position(s) for highlighting specific 1262 features. By default, arrow points from lower left to the designated 1263 position "pos" with arrow length "length". Alternatively, if 1264 "starting_pos" is set, arrow will stretch from "starting_pos" to "pos" 1265 and "length" will be disregarded. 1266 1267 "coord_system" keyword refers to positions set in "pos" arg and 1268 "starting_pos" keyword, which by default are in data coordinates. 1269 1270 "length", "width", "head_length", and "head_width" keywords for the arrow 1271 are all in axis units, ie relative to the size of the plot axes as 1, 1272 even if the position of the arrow is set relative to another coordinate 1273 system. 1274 1275 Parameters 1276 ---------- 1277 pos : array-like 1278 These are the coordinates where the marker(s) will be overplotted 1279 Either as [x,y,z] or as [[x1,x2,...],[y1,y2,...],[z1,z2,...]] 1280 1281 length : float, optional 1282 The length, in axis units, of the arrow. 1283 Default: 0.03 1284 1285 width : float, optional 1286 The width, in axis units, of the tail line of the arrow. 1287 Default: 0.003 1288 1289 head_length : float, optional 1290 The length, in axis units, of the head of the arrow. If set 1291 to None, use 1.5*head_width 1292 Default: None 1293 1294 head_width : float, optional 1295 The width, in axis units, of the head of the arrow. 1296 Default: 0.02 1297 1298 starting_pos : 2- or 3-element tuple, list, or array, optional 1299 These are the coordinates from which the arrow starts towards its 1300 point. Not compatible with 'length' kwarg. 1301 1302 coord_system : string, optional 1303 This string defines the coordinate system of the coordinates of pos 1304 Valid coordinates are: 1305 1306 "data" -- the 3D dataset coordinates 1307 1308 "plot" -- the 2D coordinates defined by the actual plot limits 1309 1310 "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is 1311 upper right 1312 1313 "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1) 1314 is upper right 1315 1316 plot_args : dictionary, optional 1317 This dictionary is passed to the MPL arrow function for generating 1318 the arrow. By default, it is: {'color':'white'} 1319 1320 Examples 1321 -------- 1322 1323 >>> # Overplot an arrow pointing to feature at data coord: (0.2, 0.3, 0.4) 1324 >>> import yt 1325 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1326 >>> s = yt.SlicePlot(ds, "z", "density") 1327 >>> s.annotate_arrow([0.2, 0.3, 0.4]) 1328 >>> s.save() 1329 1330 >>> # Overplot a red arrow with longer length pointing to plot coordinate 1331 >>> # (0.1, -0.1) 1332 >>> import yt 1333 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1334 >>> s = yt.SlicePlot(ds, "z", "density") 1335 >>> s.annotate_arrow( 1336 ... [0.1, -0.1], length=0.06, coord_system="plot", plot_args={"color": "red"} 1337 ... ) 1338 >>> s.save() 1339 1340 """ 1341 1342 _type_name = "arrow" 1343 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 1344 1345 def __init__( 1346 self, 1347 pos, 1348 code_size=None, 1349 length=0.03, 1350 width=0.0001, 1351 head_width=0.01, 1352 head_length=0.01, 1353 starting_pos=None, 1354 coord_system="data", 1355 plot_args=None, 1356 ): 1357 def_plot_args = {"color": "white"} 1358 self.pos = pos 1359 self.code_size = code_size 1360 self.length = length 1361 self.width = width 1362 self.head_width = head_width 1363 self.head_length = head_length 1364 self.starting_pos = starting_pos 1365 self.coord_system = coord_system 1366 self.transform = None 1367 if plot_args is None: 1368 plot_args = def_plot_args 1369 self.plot_args = plot_args 1370 1371 def __call__(self, plot): 1372 x, y = self._sanitize_coord_system( 1373 plot, self.pos, coord_system=self.coord_system 1374 ) 1375 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1376 # normalize all of the kwarg lengths to the plot size 1377 plot_diag = ((yy1 - yy0) ** 2 + (xx1 - xx0) ** 2) ** (0.5) 1378 self.length *= plot_diag 1379 self.width *= plot_diag 1380 self.head_width *= plot_diag 1381 if self.head_length is not None: 1382 self.head_length *= plot_diag 1383 if self.code_size is not None: 1384 warnings.warn( 1385 "The code_size keyword is deprecated. Please use " 1386 "the length keyword in 'axis' units instead. " 1387 "Setting code_size overrides length value." 1388 ) 1389 if is_sequence(self.code_size): 1390 self.code_size = plot.data.ds.quan(self.code_size[0], self.code_size[1]) 1391 self.code_size = np.float64(self.code_size.in_units(plot.xlim[0].units)) 1392 self.code_size = self.code_size * self._pixel_scale(plot)[0] 1393 dx = dy = self.code_size 1394 else: 1395 if self.starting_pos is not None: 1396 start_x, start_y = self._sanitize_coord_system( 1397 plot, self.starting_pos, coord_system=self.coord_system 1398 ) 1399 dx = x - start_x 1400 dy = y - start_y 1401 else: 1402 dx = (xx1 - xx0) * 2 ** (0.5) * self.length 1403 dy = (yy1 - yy0) * 2 ** (0.5) * self.length 1404 # If the arrow is 0 length 1405 if dx == dy == 0: 1406 warnings.warn("The arrow has zero length. Not annotating.") 1407 return 1408 try: 1409 plot._axes.arrow( 1410 x - dx, 1411 y - dy, 1412 dx, 1413 dy, 1414 width=self.width, 1415 head_width=self.head_width, 1416 head_length=self.head_length, 1417 transform=self.transform, 1418 length_includes_head=True, 1419 **self.plot_args, 1420 ) 1421 except ValueError: 1422 for i in range(len(x)): 1423 plot._axes.arrow( 1424 x[i] - dx, 1425 y[i] - dy, 1426 dx, 1427 dy, 1428 width=self.width, 1429 head_width=self.head_width, 1430 head_length=self.head_length, 1431 transform=self.transform, 1432 length_includes_head=True, 1433 **self.plot_args, 1434 ) 1435 plot._axes.set_xlim(xx0, xx1) 1436 plot._axes.set_ylim(yy0, yy1) 1437 1438 1439class MarkerAnnotateCallback(PlotCallback): 1440 """ 1441 Overplot marker(s) at a position(s) for highlighting specific features. 1442 1443 Parameters 1444 ---------- 1445 pos : array-like 1446 These are the coordinates where the marker(s) will be overplotted 1447 Either as [x,y,z] or as [[x1,x2,...],[y1,y2,...],[z1,z2,...]] 1448 1449 marker : string, optional 1450 The shape of the marker to be passed to the MPL scatter function. 1451 By default, it is 'x', but other acceptable values are: '.', 'o', 'v', 1452 '^', 's', 'p' '*', etc. See matplotlib.markers for more information. 1453 1454 coord_system : string, optional 1455 This string defines the coordinate system of the coordinates of pos 1456 Valid coordinates are: 1457 1458 "data" -- the 3D dataset coordinates 1459 1460 "plot" -- the 2D coordinates defined by the actual plot limits 1461 1462 "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is 1463 upper right 1464 1465 "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1) 1466 is upper right 1467 1468 plot_args : dictionary, optional 1469 This dictionary is passed to the MPL scatter function for generating 1470 the marker. By default, it is: {'color':'white', 's':50} 1471 1472 Examples 1473 -------- 1474 1475 >>> # Overplot a white X on a feature at data location (0.5, 0.5, 0.5) 1476 >>> import yt 1477 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1478 >>> s = yt.SlicePlot(ds, "z", "density") 1479 >>> s.annotate_marker([0.4, 0.5, 0.6]) 1480 >>> s.save() 1481 1482 >>> # Overplot a big yellow circle at axis location (0.1, 0.2) 1483 >>> import yt 1484 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1485 >>> s = yt.SlicePlot(ds, "z", "density") 1486 >>> s.annotate_marker( 1487 ... [0.1, 0.2], 1488 ... marker="o", 1489 ... coord_system="axis", 1490 ... plot_args={"color": "yellow", "s": 200}, 1491 ... ) 1492 >>> s.save() 1493 1494 """ 1495 1496 _type_name = "marker" 1497 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 1498 1499 def __init__(self, pos, marker="x", coord_system="data", plot_args=None): 1500 def_plot_args = {"color": "w", "s": 50} 1501 self.pos = pos 1502 self.marker = marker 1503 if plot_args is None: 1504 plot_args = def_plot_args 1505 self.plot_args = plot_args 1506 self.coord_system = coord_system 1507 self.transform = None 1508 1509 def __call__(self, plot): 1510 x, y = self._sanitize_coord_system( 1511 plot, self.pos, coord_system=self.coord_system 1512 ) 1513 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1514 plot._axes.scatter( 1515 x, y, marker=self.marker, transform=self.transform, **self.plot_args 1516 ) 1517 plot._axes.set_xlim(xx0, xx1) 1518 plot._axes.set_ylim(yy0, yy1) 1519 1520 1521class SphereCallback(PlotCallback): 1522 """ 1523 Overplot a circle with designated center and radius with optional text. 1524 1525 Parameters 1526 ---------- 1527 center : 2- or 3-element tuple, list, or array 1528 These are the coordinates where the circle will be overplotted 1529 1530 radius : YTArray, float, or (1, ('kpc')) style tuple 1531 The radius of the circle in code coordinates 1532 1533 circle_args : dict, optional 1534 This dictionary is passed to the MPL circle object. By default, 1535 {'color':'white'} 1536 1537 coord_system : string, optional 1538 This string defines the coordinate system of the coordinates of pos 1539 Valid coordinates are: 1540 1541 "data" -- the 3D dataset coordinates 1542 1543 "plot" -- the 2D coordinates defined by the actual plot limits 1544 1545 "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is 1546 upper right 1547 1548 "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1) 1549 is upper right 1550 1551 text : string, optional 1552 Optional text to include next to the circle. 1553 1554 text_args : dictionary, optional 1555 This dictionary is passed to the MPL text function. By default, 1556 it is: {'color':'white'} 1557 1558 Examples 1559 -------- 1560 1561 >>> # Overplot a white circle of radius 100 kpc over the central galaxy 1562 >>> import yt 1563 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1564 >>> s = yt.SlicePlot(ds, "z", "density") 1565 >>> s.annotate_sphere([0.5, 0.5, 0.5], radius=(100, "kpc")) 1566 >>> s.save() 1567 1568 """ 1569 1570 _type_name = "sphere" 1571 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 1572 1573 def __init__( 1574 self, 1575 center, 1576 radius, 1577 circle_args=None, 1578 text=None, 1579 coord_system="data", 1580 text_args=None, 1581 ): 1582 def_text_args = {"color": "white"} 1583 def_circle_args = {"color": "white"} 1584 self.center = center 1585 self.radius = radius 1586 if circle_args is None: 1587 circle_args = def_circle_args 1588 if "fill" not in circle_args: 1589 circle_args["fill"] = False 1590 self.circle_args = circle_args 1591 self.text = text 1592 if text_args is None: 1593 text_args = def_text_args 1594 self.text_args = text_args 1595 self.coord_system = coord_system 1596 self.transform = None 1597 1598 def __call__(self, plot): 1599 from matplotlib.patches import Circle 1600 1601 if is_sequence(self.radius): 1602 self.radius = plot.data.ds.quan(self.radius[0], self.radius[1]) 1603 self.radius = np.float64(self.radius.in_units(plot.xlim[0].units)) 1604 if isinstance(self.radius, YTQuantity): 1605 if isinstance(self.center, YTArray): 1606 units = self.center.units 1607 else: 1608 units = "code_length" 1609 self.radius = self.radius.to(units) 1610 1611 # This assures the radius has the appropriate size in 1612 # the different coordinate systems, since one cannot simply 1613 # apply a different transform for a length in the same way 1614 # you can for a coordinate. 1615 if self.coord_system == "data" or self.coord_system == "plot": 1616 self.radius = self.radius * self._pixel_scale(plot)[0] 1617 else: 1618 self.radius /= (plot.xlim[1] - plot.xlim[0]).v 1619 1620 x, y = self._sanitize_coord_system( 1621 plot, self.center, coord_system=self.coord_system 1622 ) 1623 1624 cir = Circle((x, y), self.radius, transform=self.transform, **self.circle_args) 1625 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1626 1627 plot._axes.add_patch(cir) 1628 if self.text is not None: 1629 label = plot._axes.text( 1630 x, y, self.text, transform=self.transform, **self.text_args 1631 ) 1632 self._set_font_properties(plot, [label], **self.text_args) 1633 1634 plot._axes.set_xlim(xx0, xx1) 1635 plot._axes.set_ylim(yy0, yy1) 1636 1637 1638class TextLabelCallback(PlotCallback): 1639 """ 1640 Overplot text on the plot at a specified position. If you desire an inset 1641 box around your text, set one with the inset_box_args dictionary 1642 keyword. 1643 1644 Parameters 1645 ---------- 1646 pos : 2- or 3-element tuple, list, or array 1647 These are the coordinates where the text will be overplotted 1648 1649 text : string 1650 The text you wish to include 1651 1652 coord_system : string, optional 1653 This string defines the coordinate system of the coordinates of pos 1654 Valid coordinates are: 1655 1656 "data" -- the 3D dataset coordinates 1657 1658 "plot" -- the 2D coordinates defined by the actual plot limits 1659 1660 "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is 1661 upper right 1662 1663 "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1) 1664 is upper right 1665 1666 text_args : dictionary, optional 1667 This dictionary is passed to the MPL text function for generating 1668 the text. By default, it is: {'color':'white'} and uses the defaults 1669 for the other fonts in the image. 1670 1671 inset_box_args : dictionary, optional 1672 A dictionary of any arbitrary parameters to be passed to the Matplotlib 1673 FancyBboxPatch object as the inset box around the text. Default: {} 1674 1675 Examples 1676 -------- 1677 1678 >>> # Overplot white text at data location [0.55, 0.7, 0.4] 1679 >>> import yt 1680 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1681 >>> s = yt.SlicePlot(ds, "z", "density") 1682 >>> s.annotate_text([0.55, 0.7, 0.4], "Here is a galaxy") 1683 >>> s.save() 1684 1685 >>> # Overplot yellow text at axis location [0.2, 0.8] with 1686 >>> # a shaded inset box 1687 >>> import yt 1688 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1689 >>> s = yt.SlicePlot(ds, "z", "density") 1690 >>> s.annotate_text( 1691 ... [0.2, 0.8], 1692 ... "Here is a galaxy", 1693 ... coord_system="axis", 1694 ... text_args={"color": "yellow"}, 1695 ... inset_box_args={ 1696 ... "boxstyle": "square,pad=0.3", 1697 ... "facecolor": "black", 1698 ... "linewidth": 3, 1699 ... "edgecolor": "white", 1700 ... "alpha": 0.5, 1701 ... }, 1702 ... ) 1703 >>> s.save() 1704 """ 1705 1706 _type_name = "text" 1707 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 1708 1709 def __init__( 1710 self, 1711 pos, 1712 text, 1713 data_coords=False, 1714 coord_system="data", 1715 text_args=None, 1716 inset_box_args=None, 1717 ): 1718 def_text_args = {"color": "white"} 1719 self.pos = pos 1720 self.text = text 1721 if data_coords: 1722 coord_system = "data" 1723 warnings.warn( 1724 "The data_coords keyword is deprecated. Please set " 1725 "the keyword coord_system='data' instead." 1726 ) 1727 if text_args is None: 1728 text_args = def_text_args 1729 self.text_args = text_args 1730 self.inset_box_args = inset_box_args 1731 self.coord_system = coord_system 1732 self.transform = None 1733 1734 def __call__(self, plot): 1735 kwargs = self.text_args.copy() 1736 x, y = self._sanitize_coord_system( 1737 plot, self.pos, coord_system=self.coord_system 1738 ) 1739 1740 # Set the font properties of text from this callback to be 1741 # consistent with other text labels in this figure 1742 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1743 if self.inset_box_args is not None: 1744 kwargs["bbox"] = self.inset_box_args 1745 label = plot._axes.text(x, y, self.text, transform=self.transform, **kwargs) 1746 self._set_font_properties(plot, [label], **kwargs) 1747 plot._axes.set_xlim(xx0, xx1) 1748 plot._axes.set_ylim(yy0, yy1) 1749 1750 1751class PointAnnotateCallback(TextLabelCallback): 1752 """ 1753 This callback is deprecated, as it is simply a wrapper around 1754 the TextLabelCallback (ie annotate_text()). Please see TextLabelCallback 1755 for more information. 1756 1757 """ 1758 1759 _type_name = "point" 1760 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 1761 1762 def __init__( 1763 self, 1764 pos, 1765 text, 1766 data_coords=False, 1767 coord_system="data", 1768 text_args=None, 1769 inset_box_args=None, 1770 ): 1771 super().__init__( 1772 pos, text, data_coords, coord_system, text_args, inset_box_args 1773 ) 1774 warnings.warn( 1775 "The PointAnnotateCallback (annotate_point()) is " 1776 "deprecated. Please use the TextLabelCallback " 1777 "(annotate_point()) instead." 1778 ) 1779 1780 def __call__(self, plot): 1781 super().__call__(plot) 1782 1783 1784class HaloCatalogCallback(PlotCallback): 1785 """ 1786 Plots circles at the locations of all the halos 1787 in a halo catalog with radii corresponding to the 1788 virial radius of each halo. 1789 1790 Note, this functionality requires the yt_astro_analysis 1791 package. See https://yt-astro-analysis.readthedocs.io/ 1792 for more information. 1793 1794 Parameters 1795 ---------- 1796 halo_catalog : Dataset, DataContainer, 1797 or ~yt.analysis_modules.halo_analysis.halo_catalog.HaloCatalog 1798 The object containing halos to be overplotted. This can 1799 be a HaloCatalog object, a loaded halo catalog dataset, 1800 or a data container from a halo catalog dataset. 1801 circle_args : list 1802 Contains the arguments controlling the 1803 appearance of the circles, supplied to the 1804 Matplotlib patch Circle. 1805 width : tuple 1806 The width over which to select halos to plot, 1807 useful when overplotting to a slice plot. Accepts 1808 a tuple in the form (1.0, 'Mpc'). 1809 annotate_field : str 1810 A field contained in the 1811 halo catalog to add text to the plot near the halo. 1812 Example: annotate_field = 'particle_mass' will 1813 write the halo mass next to each halo. 1814 radius_field : str 1815 A field contained in the halo 1816 catalog to set the radius of the circle which will 1817 surround each halo. Default: 'virial_radius'. 1818 center_field_prefix : str 1819 Accepts a field prefix which will 1820 be used to find the fields containing the coordinates 1821 of the center of each halo. Ex: 'particle_position' 1822 will result in the fields 'particle_position_x' for x 1823 'particle_position_y' for y, and 'particle_position_z' 1824 for z. Default: 'particle_position'. 1825 text_args : dict 1826 Contains the arguments controlling the text 1827 appearance of the annotated field. 1828 factor : float 1829 A number the virial radius is multiplied by for 1830 plotting the circles. Ex: factor = 2.0 will plot 1831 circles with twice the radius of each halo virial radius. 1832 1833 Examples 1834 -------- 1835 1836 >>> import yt 1837 >>> dds = yt.load("Enzo_64/DD0043/data0043") 1838 >>> hds = yt.load("rockstar_halos/halos_0.0.bin") 1839 >>> p = yt.ProjectionPlot( 1840 ... dds, "x", ("gas", "density"), weight_field=("gas", "density") 1841 ... ) 1842 >>> p.annotate_halos(hds) 1843 >>> p.save() 1844 1845 >>> # plot a subset of all halos 1846 >>> import yt 1847 >>> dds = yt.load("Enzo_64/DD0043/data0043") 1848 >>> hds = yt.load("rockstar_halos/halos_0.0.bin") 1849 >>> # make a region half the width of the box 1850 >>> dregion = dds.box( 1851 ... dds.domain_center - 0.25 * dds.domain_width, 1852 ... dds.domain_center + 0.25 * dds.domain_width, 1853 ... ) 1854 >>> hregion = hds.box( 1855 ... hds.domain_center - 0.25 * hds.domain_width, 1856 ... hds.domain_center + 0.25 * hds.domain_width, 1857 ... ) 1858 >>> p = yt.ProjectionPlot( 1859 ... dds, 1860 ... "x", 1861 ... ("gas", "density"), 1862 ... weight_field=("gas", "density"), 1863 ... data_source=dregion, 1864 ... width=0.5, 1865 ... ) 1866 >>> p.annotate_halos(hregion) 1867 >>> p.save() 1868 1869 >>> # plot halos from a HaloCatalog 1870 >>> import yt 1871 >>> from yt.extensions.astro_analysis.halo_analysis.api import HaloCatalog 1872 >>> dds = yt.load("Enzo_64/DD0043/data0043") 1873 >>> hds = yt.load("rockstar_halos/halos_0.0.bin") 1874 >>> hc = HaloCatalog(data_ds=dds, halos_ds=hds) 1875 >>> p = yt.ProjectionPlot( 1876 ... dds, "x", ("gas", "density"), weight_field=("gas", "density") 1877 ... ) 1878 >>> p.annotate_halos(hc) 1879 >>> p.save() 1880 1881 """ 1882 1883 _type_name = "halos" 1884 region = None 1885 _descriptor = None 1886 _supported_geometries = ("cartesian", "spectral_cube") 1887 1888 def __init__( 1889 self, 1890 halo_catalog, 1891 circle_args=None, 1892 circle_kwargs=None, 1893 width=None, 1894 annotate_field=None, 1895 radius_field="virial_radius", 1896 center_field_prefix="particle_position", 1897 text_args=None, 1898 font_kwargs=None, 1899 factor=1.0, 1900 ): 1901 1902 try: 1903 from yt_astro_analysis.halo_analysis.api import HaloCatalog 1904 except ImportError: 1905 HaloCatalog = NotAModule("yt_astro_analysis") 1906 1907 PlotCallback.__init__(self) 1908 def_circle_args = {"edgecolor": "white", "facecolor": "None"} 1909 def_text_args = {"color": "white"} 1910 1911 if isinstance(halo_catalog, YTDataContainer): 1912 self.halo_data = halo_catalog 1913 elif isinstance(halo_catalog, Dataset): 1914 self.halo_data = halo_catalog.all_data() 1915 elif isinstance(halo_catalog, HaloCatalog): 1916 if halo_catalog.data_source.ds == halo_catalog.halos_ds: 1917 self.halo_data = halo_catalog.data_source 1918 else: 1919 self.halo_data = halo_catalog.halos_ds.all_data() 1920 else: 1921 raise RuntimeError( 1922 "halo_catalog argument must be a HaloCatalog object, " 1923 + "a dataset, or a data container." 1924 ) 1925 1926 self.width = width 1927 self.radius_field = radius_field 1928 self.center_field_prefix = center_field_prefix 1929 self.annotate_field = annotate_field 1930 if circle_kwargs is not None: 1931 circle_args = circle_kwargs 1932 warnings.warn( 1933 "The circle_kwargs keyword is deprecated. Please " 1934 "use the circle_args keyword instead." 1935 ) 1936 if font_kwargs is not None: 1937 text_args = font_kwargs 1938 warnings.warn( 1939 "The font_kwargs keyword is deprecated. Please use " 1940 "the text_args keyword instead." 1941 ) 1942 if circle_args is None: 1943 circle_args = def_circle_args 1944 self.circle_args = circle_args 1945 if text_args is None: 1946 text_args = def_text_args 1947 self.text_args = text_args 1948 self.factor = factor 1949 1950 def __call__(self, plot): 1951 from matplotlib.patches import Circle 1952 1953 data = plot.data 1954 x0, x1, y0, y1 = self._physical_bounds(plot) 1955 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 1956 1957 halo_data = self.halo_data 1958 axis_names = plot.data.ds.coordinates.axis_name 1959 xax = plot.data.ds.coordinates.x_axis[data.axis] 1960 yax = plot.data.ds.coordinates.y_axis[data.axis] 1961 field_x = f"{self.center_field_prefix}_{axis_names[xax]}" 1962 field_y = f"{self.center_field_prefix}_{axis_names[yax]}" 1963 field_z = f"{self.center_field_prefix}_{axis_names[data.axis]}" 1964 1965 # Set up scales for pixel size and original data 1966 pixel_scale = self._pixel_scale(plot)[0] 1967 units = plot.xlim[0].units 1968 1969 # Convert halo positions to code units of the plotted data 1970 # and then to units of the plotted window 1971 px = halo_data[("all", field_x)][:].in_units(units) 1972 py = halo_data[("all", field_y)][:].in_units(units) 1973 1974 xplotcenter = (plot.xlim[0] + plot.xlim[1]) / 2 1975 yplotcenter = (plot.ylim[0] + plot.ylim[1]) / 2 1976 1977 xdomaincenter = plot.ds.domain_center[xax] 1978 ydomaincenter = plot.ds.domain_center[yax] 1979 1980 xoffset = xplotcenter - xdomaincenter 1981 yoffset = yplotcenter - ydomaincenter 1982 1983 xdw = plot.ds.domain_width[xax].to(units) 1984 ydw = plot.ds.domain_width[yax].to(units) 1985 1986 modpx = np.mod(px - xoffset, xdw) + xoffset 1987 modpy = np.mod(py - yoffset, ydw) + yoffset 1988 1989 px[modpx != px] = modpx[modpx != px] 1990 py[modpy != py] = modpy[modpy != py] 1991 1992 px, py = self._convert_to_plot(plot, [px, py]) 1993 1994 # Convert halo radii to a radius in pixels 1995 radius = halo_data[("all", self.radius_field)][:].in_units(units) 1996 radius = np.array(radius * pixel_scale * self.factor) 1997 1998 if self.width: 1999 pz = halo_data[("all", field_z)][:].in_units("code_length") 2000 c = data.center[data.axis] 2001 2002 # I should catch an error here if width isn't in this form 2003 # but I dont really want to reimplement get_sanitized_width... 2004 width = data.ds.arr(self.width[0], self.width[1]).in_units("code_length") 2005 2006 indices = np.where((pz > c - 0.5 * width) & (pz < c + 0.5 * width)) 2007 2008 px = px[indices] 2009 py = py[indices] 2010 radius = radius[indices] 2011 2012 for x, y, r in zip(px, py, radius): 2013 plot._axes.add_artist(Circle(xy=(x, y), radius=r, **self.circle_args)) 2014 2015 plot._axes.set_xlim(xx0, xx1) 2016 plot._axes.set_ylim(yy0, yy1) 2017 2018 if self.annotate_field: 2019 annotate_dat = halo_data[("all", self.annotate_field)] 2020 texts = [f"{float(dat):g}" for dat in annotate_dat] 2021 labels = [] 2022 for pos_x, pos_y, t in zip(px, py, texts): 2023 labels.append(plot._axes.text(pos_x, pos_y, t, **self.text_args)) 2024 2025 # Set the font properties of text from this callback to be 2026 # consistent with other text labels in this figure 2027 self._set_font_properties(plot, labels, **self.text_args) 2028 2029 2030class ParticleCallback(PlotCallback): 2031 """ 2032 Adds particle positions, based on a thick slab along *axis* with a 2033 *width* along the line of sight. *p_size* controls the number of 2034 pixels per particle, and *col* governs the color. *ptype* will 2035 restrict plotted particles to only those that are of a given type. 2036 *alpha* determines the opacity of the marker symbol used in the scatter. 2037 An alternate data source can be specified with *data_source*, but by 2038 default the plot's data source will be queried. 2039 """ 2040 2041 _type_name = "particles" 2042 region = None 2043 _descriptor = None 2044 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 2045 2046 def __init__( 2047 self, 2048 width, 2049 p_size=1.0, 2050 col="k", 2051 marker="o", 2052 stride=1, 2053 ptype="all", 2054 minimum_mass=None, 2055 alpha=1.0, 2056 data_source=None, 2057 ): 2058 PlotCallback.__init__(self) 2059 self.width = width 2060 self.p_size = p_size 2061 self.color = col 2062 self.marker = marker 2063 self.stride = stride 2064 self.ptype = ptype 2065 self.minimum_mass = minimum_mass 2066 self.alpha = alpha 2067 self.data_source = data_source 2068 if self.minimum_mass is not None: 2069 warnings.warn( 2070 "The minimum_mass keyword is deprecated. Please use " 2071 "an appropriate particle filter and the ptype keyword instead." 2072 ) 2073 2074 def __call__(self, plot): 2075 data = plot.data 2076 if is_sequence(self.width): 2077 validate_width_tuple(self.width) 2078 self.width = plot.data.ds.quan(self.width[0], self.width[1]) 2079 elif isinstance(self.width, YTQuantity): 2080 self.width = plot.data.ds.quan(self.width.value, self.width.units) 2081 else: 2082 self.width = plot.data.ds.quan(self.width, "code_length") 2083 # we construct a rectangular prism 2084 x0, x1, y0, y1 = self._physical_bounds(plot) 2085 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 2086 if isinstance(self.data_source, YTCutRegion): 2087 mylog.warning( 2088 "Parameter 'width' is ignored in annotate_particles if the " 2089 "data_source is a cut_region. " 2090 "See https://github.com/yt-project/yt/issues/1933 for further details." 2091 ) 2092 self.region = self.data_source 2093 else: 2094 self.region = self._get_region((x0, x1), (y0, y1), plot.data.axis, data) 2095 ax = data.axis 2096 xax = plot.data.ds.coordinates.x_axis[ax] 2097 yax = plot.data.ds.coordinates.y_axis[ax] 2098 axis_names = plot.data.ds.coordinates.axis_name 2099 field_x = f"particle_position_{axis_names[xax]}" 2100 field_y = f"particle_position_{axis_names[yax]}" 2101 pt = self.ptype 2102 self.periodic_x = plot.data.ds.periodicity[xax] 2103 self.periodic_y = plot.data.ds.periodicity[yax] 2104 self.LE = plot.data.ds.domain_left_edge[xax], plot.data.ds.domain_left_edge[yax] 2105 self.RE = ( 2106 plot.data.ds.domain_right_edge[xax], 2107 plot.data.ds.domain_right_edge[yax], 2108 ) 2109 period_x = plot.data.ds.domain_width[xax] 2110 period_y = plot.data.ds.domain_width[yax] 2111 particle_x, particle_y = self._enforce_periodic( 2112 self.region[pt, field_x], 2113 self.region[pt, field_y], 2114 x0, 2115 x1, 2116 period_x, 2117 y0, 2118 y1, 2119 period_y, 2120 ) 2121 gg = ( 2122 (particle_x >= x0) 2123 & (particle_x <= x1) 2124 & (particle_y >= y0) 2125 & (particle_y <= y1) 2126 ) 2127 if self.minimum_mass is not None: 2128 gg &= self.region[pt, "particle_mass"] >= self.minimum_mass 2129 if gg.sum() == 0: 2130 return 2131 px, py = [particle_x[gg][:: self.stride], particle_y[gg][:: self.stride]] 2132 px, py = self._convert_to_plot(plot, [px, py]) 2133 plot._axes.scatter( 2134 px, 2135 py, 2136 edgecolors="None", 2137 marker=self.marker, 2138 s=self.p_size, 2139 c=self.color, 2140 alpha=self.alpha, 2141 ) 2142 plot._axes.set_xlim(xx0, xx1) 2143 plot._axes.set_ylim(yy0, yy1) 2144 2145 def _enforce_periodic( 2146 self, particle_x, particle_y, x0, x1, period_x, y0, y1, period_y 2147 ): 2148 # duplicate particles if periodic in that direction AND if the plot 2149 # extends outside the domain boundaries. 2150 if self.periodic_x and x0 > self.RE[0]: 2151 particle_x = uhstack((particle_x, particle_x + period_x)) 2152 particle_y = uhstack((particle_y, particle_y)) 2153 if self.periodic_x and x1 < self.LE[0]: 2154 particle_x = uhstack((particle_x, particle_x - period_x)) 2155 particle_y = uhstack((particle_y, particle_y)) 2156 if self.periodic_y and y0 > self.RE[1]: 2157 particle_y = uhstack((particle_y, particle_y + period_y)) 2158 particle_x = uhstack((particle_x, particle_x)) 2159 if self.periodic_y and y1 < self.LE[1]: 2160 particle_y = uhstack((particle_y, particle_y - period_y)) 2161 particle_x = uhstack((particle_x, particle_x)) 2162 return particle_x, particle_y 2163 2164 def _get_region(self, xlim, ylim, axis, data): 2165 LE, RE = [None] * 3, [None] * 3 2166 ds = data.ds 2167 xax = ds.coordinates.x_axis[axis] 2168 yax = ds.coordinates.y_axis[axis] 2169 zax = axis 2170 LE[xax], RE[xax] = xlim 2171 LE[yax], RE[yax] = ylim 2172 LE[zax] = data.center[zax] - self.width * 0.5 2173 LE[zax].convert_to_units("code_length") 2174 RE[zax] = LE[zax] + self.width 2175 if ( 2176 self.region is not None 2177 and np.all(self.region.left_edge <= LE) 2178 and np.all(self.region.right_edge >= RE) 2179 ): 2180 return self.region 2181 self.region = data.ds.region(data.center, LE, RE, data_source=self.data_source) 2182 return self.region 2183 2184 2185class TitleCallback(PlotCallback): 2186 """ 2187 Accepts a *title* and adds it to the plot 2188 """ 2189 2190 _type_name = "title" 2191 2192 def __init__(self, title): 2193 PlotCallback.__init__(self) 2194 self.title = title 2195 2196 def __call__(self, plot): 2197 plot._axes.set_title(self.title) 2198 # Set the font properties of text from this callback to be 2199 # consistent with other text labels in this figure 2200 label = plot._axes.title 2201 self._set_font_properties(plot, [label]) 2202 2203 2204class MeshLinesCallback(PlotCallback): 2205 """ 2206 Adds mesh lines to the plot. Only works for unstructured or 2207 semi-structured mesh data. For structured grid data, see 2208 GridBoundaryCallback or CellEdgesCallback. 2209 2210 Parameters 2211 ---------- 2212 2213 plot_args: dict, optional 2214 A dictionary of arguments that will be passed to matplotlib. 2215 2216 Example 2217 ------- 2218 2219 >>> import yt 2220 >>> ds = yt.load("MOOSE_sample_data/out.e-s010") 2221 >>> sl = yt.SlicePlot(ds, "z", ("connect2", "convected")) 2222 >>> sl.annotate_mesh_lines(plot_args={"color": "black"}) 2223 2224 """ 2225 2226 _type_name = "mesh_lines" 2227 _supported_geometries = ("cartesian", "spectral_cube") 2228 2229 def __init__(self, plot_args=None): 2230 super().__init__() 2231 self.plot_args = plot_args 2232 2233 def promote_2d_to_3d(self, coords, indices, plot): 2234 new_coords = np.zeros((2 * coords.shape[0], 3)) 2235 new_connects = np.zeros( 2236 (indices.shape[0], 2 * indices.shape[1]), dtype=np.int64 2237 ) 2238 2239 new_coords[0 : coords.shape[0], 0:2] = coords 2240 new_coords[0 : coords.shape[0], 2] = plot.ds.domain_left_edge[2] 2241 new_coords[coords.shape[0] :, 0:2] = coords 2242 new_coords[coords.shape[0] :, 2] = plot.ds.domain_right_edge[2] 2243 2244 new_connects[:, 0 : indices.shape[1]] = indices 2245 new_connects[:, indices.shape[1] :] = indices + coords.shape[0] 2246 2247 return new_coords, new_connects 2248 2249 def __call__(self, plot): 2250 2251 index = plot.ds.index 2252 if not issubclass(type(index), UnstructuredIndex): 2253 raise RuntimeError( 2254 "Mesh line annotations only work for " 2255 "unstructured or semi-structured mesh data." 2256 ) 2257 for i, m in enumerate(index.meshes): 2258 try: 2259 ftype, fname = plot.field 2260 if ftype.startswith("connect") and int(ftype[-1]) - 1 != i: 2261 continue 2262 except ValueError: 2263 pass 2264 coords = m.connectivity_coords 2265 indices = m.connectivity_indices - m._index_offset 2266 2267 num_verts = indices.shape[1] 2268 num_dims = coords.shape[1] 2269 2270 if num_dims == 2 and num_verts == 3: 2271 coords, indices = self.promote_2d_to_3d(coords, indices, plot) 2272 elif num_dims == 2 and num_verts == 4: 2273 coords, indices = self.promote_2d_to_3d(coords, indices, plot) 2274 2275 tri_indices = triangulate_indices(indices.astype(np.int_)) 2276 points = coords[tri_indices] 2277 2278 tfc = TriangleFacetsCallback(points, plot_args=self.plot_args) 2279 tfc(plot) 2280 2281 2282class TriangleFacetsCallback(PlotCallback): 2283 """ 2284 Intended for representing a slice of a triangular faceted 2285 geometry in a slice plot. 2286 2287 Uses a set of *triangle_vertices* to find all triangles the plane of a 2288 SlicePlot intersects with. The lines between the intersection points 2289 of the triangles are then added to the plot to create an outline 2290 of the geometry represented by the triangles. 2291 """ 2292 2293 _type_name = "triangle_facets" 2294 _supported_geometries = ("cartesian", "spectral_cube") 2295 2296 def __init__(self, triangle_vertices, plot_args=None): 2297 super().__init__() 2298 self.plot_args = {} if plot_args is None else plot_args 2299 self.vertices = triangle_vertices 2300 2301 def __call__(self, plot): 2302 ax = plot.data.axis 2303 xax = plot.data.ds.coordinates.x_axis[ax] 2304 yax = plot.data.ds.coordinates.y_axis[ax] 2305 2306 if not hasattr(self.vertices, "in_units"): 2307 vertices = plot.data.pf.arr(self.vertices, "code_length") 2308 else: 2309 vertices = self.vertices 2310 l_cy = triangle_plane_intersect(plot.data.axis, plot.data.coord, vertices)[ 2311 :, :, (xax, yax) 2312 ] 2313 # l_cy is shape (nlines, 2, 2) 2314 # reformat for conversion to plot coordinates 2315 l_cy = np.rollaxis(l_cy, 0, 3) 2316 # convert all line starting points 2317 l_cy[0] = self._convert_to_plot(plot, l_cy[0]) 2318 # convert all line ending points 2319 l_cy[1] = self._convert_to_plot(plot, l_cy[1]) 2320 # convert back to shape (nlines, 2, 2) 2321 l_cy = np.rollaxis(l_cy, 2, 0) 2322 # create line collection and add it to the plot 2323 lc = matplotlib.collections.LineCollection(l_cy, **self.plot_args) 2324 plot._axes.add_collection(lc) 2325 2326 2327class TimestampCallback(PlotCallback): 2328 r""" 2329 Annotates the timestamp and/or redshift of the data output at a specified 2330 location in the image (either in a present corner, or by specifying (x,y) 2331 image coordinates with the x_pos, y_pos arguments. If no time_units are 2332 specified, it will automatically choose appropriate units. It allows for 2333 custom formatting of the time and redshift information, as well as the 2334 specification of an inset box around the text. 2335 2336 Parameters 2337 ---------- 2338 2339 x_pos, y_pos : floats, optional 2340 The image location of the timestamp in the coord system defined by the 2341 coord_system kwarg. Setting x_pos and y_pos overrides the corner 2342 parameter. 2343 2344 corner : string, optional 2345 Corner sets up one of 4 predeterimined locations for the timestamp 2346 to be displayed in the image: 'upper_left', 'upper_right', 'lower_left', 2347 'lower_right' (also allows None). This value will be overridden by the 2348 optional x_pos and y_pos keywords. 2349 2350 time : boolean, optional 2351 Whether or not to show the ds.current_time of the data output. Can 2352 be used solo or in conjunction with redshift parameter. 2353 2354 redshift : boolean, optional 2355 Whether or not to show the ds.current_time of the data output. Can 2356 be used solo or in conjunction with the time parameter. 2357 2358 time_format : string, optional 2359 This specifies the format of the time output assuming "time" is the 2360 number of time and "unit" is units of the time (e.g. 's', 'Myr', etc.) 2361 The time can be specified to arbitrary precision according to printf 2362 formatting codes (defaults to .1f -- a float with 1 digits after 2363 decimal). Example: "Age = {time:.2f} {units}". 2364 2365 time_unit : string, optional 2366 time_unit must be a valid yt time unit (e.g. 's', 'min', 'hr', 'yr', 2367 'Myr', etc.) 2368 2369 redshift_format : string, optional 2370 This specifies the format of the redshift output. The redshift can 2371 be specified to arbitrary precision according to printf formatting 2372 codes (defaults to 0.2f -- a float with 2 digits after decimal). 2373 Example: "REDSHIFT = {redshift:03.3g}", 2374 2375 draw_inset_box : boolean, optional 2376 Whether or not an inset box should be included around the text 2377 If so, it uses the inset_box_args to set the matplotlib FancyBboxPatch 2378 object. 2379 2380 coord_system : string, optional 2381 This string defines the coordinate system of the coordinates of pos 2382 Valid coordinates are: 2383 2384 - "data": 3D dataset coordinates 2385 - "plot": 2D coordinates defined by the actual plot limits 2386 - "axis": MPL axis coordinates: (0,0) is lower left; (1,1) is upper right 2387 - "figure": MPL figure coordinates: (0,0) is lower left, (1,1) is upper right 2388 2389 time_offset : float, (value, unit) tuple, or YTQuantity, optional 2390 Apply an offset to the time shown in the annotation from the 2391 value of the current time. If a scalar value with no units is 2392 passed in, the value of the *time_unit* kwarg is used for the 2393 units. Default: None, meaning no offset. 2394 2395 text_args : dictionary, optional 2396 A dictionary of any arbitrary parameters to be passed to the Matplotlib 2397 text object. Defaults: ``{'color':'white', 2398 'horizontalalignment':'center', 'verticalalignment':'top'}``. 2399 2400 inset_box_args : dictionary, optional 2401 A dictionary of any arbitrary parameters to be passed to the Matplotlib 2402 FancyBboxPatch object as the inset box around the text. 2403 Defaults: ``{'boxstyle':'square', 'pad':0.3, 'facecolor':'black', 2404 'linewidth':3, 'edgecolor':'white', 'alpha':0.5}`` 2405 2406 Example 2407 ------- 2408 2409 >>> import yt 2410 >>> ds = yt.load("Enzo_64/DD0020/data0020") 2411 >>> s = yt.SlicePlot(ds, "z", "density") 2412 >>> s.annotate_timestamp() 2413 """ 2414 2415 _type_name = "timestamp" 2416 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 2417 2418 def __init__( 2419 self, 2420 x_pos=None, 2421 y_pos=None, 2422 corner="lower_left", 2423 time=True, 2424 redshift=False, 2425 time_format="t = {time:.1f} {units}", 2426 time_unit=None, 2427 redshift_format="z = {redshift:.2f}", 2428 draw_inset_box=False, 2429 coord_system="axis", 2430 time_offset=None, 2431 text_args=None, 2432 inset_box_args=None, 2433 ): 2434 2435 def_text_args = { 2436 "color": "white", 2437 "horizontalalignment": "center", 2438 "verticalalignment": "top", 2439 } 2440 def_inset_box_args = { 2441 "boxstyle": "square,pad=0.3", 2442 "facecolor": "black", 2443 "linewidth": 3, 2444 "edgecolor": "white", 2445 "alpha": 0.5, 2446 } 2447 2448 # Set position based on corner argument. 2449 self.pos = (x_pos, y_pos) 2450 self.corner = corner 2451 self.time = time 2452 self.redshift = redshift 2453 self.time_format = time_format 2454 self.redshift_format = redshift_format 2455 self.time_unit = time_unit 2456 self.coord_system = coord_system 2457 self.time_offset = time_offset 2458 if text_args is None: 2459 text_args = def_text_args 2460 self.text_args = text_args 2461 if inset_box_args is None: 2462 inset_box_args = def_inset_box_args 2463 self.inset_box_args = inset_box_args 2464 2465 # if inset box is not desired, set inset_box_args to {} 2466 if not draw_inset_box: 2467 self.inset_box_args = None 2468 2469 def __call__(self, plot): 2470 # Setting pos overrides corner argument 2471 if self.pos[0] is None or self.pos[1] is None: 2472 if self.corner == "upper_left": 2473 self.pos = (0.03, 0.96) 2474 self.text_args["horizontalalignment"] = "left" 2475 self.text_args["verticalalignment"] = "top" 2476 elif self.corner == "upper_right": 2477 self.pos = (0.97, 0.96) 2478 self.text_args["horizontalalignment"] = "right" 2479 self.text_args["verticalalignment"] = "top" 2480 elif self.corner == "lower_left": 2481 self.pos = (0.03, 0.03) 2482 self.text_args["horizontalalignment"] = "left" 2483 self.text_args["verticalalignment"] = "bottom" 2484 elif self.corner == "lower_right": 2485 self.pos = (0.97, 0.03) 2486 self.text_args["horizontalalignment"] = "right" 2487 self.text_args["verticalalignment"] = "bottom" 2488 elif self.corner is None: 2489 self.pos = (0.5, 0.5) 2490 self.text_args["horizontalalignment"] = "center" 2491 self.text_args["verticalalignment"] = "center" 2492 else: 2493 raise ValueError( 2494 "Argument 'corner' must be set to " 2495 "'upper_left', 'upper_right', 'lower_left', " 2496 "'lower_right', or None" 2497 ) 2498 2499 self.text = "" 2500 2501 # If we're annotating the time, put it in the correct format 2502 if self.time: 2503 # If no time_units are set, then identify a best fit time unit 2504 if self.time_unit is None: 2505 if plot.ds.unit_system._code_flag: 2506 # if the unit system is in code units 2507 # we should not convert to seconds for the plot. 2508 self.time_unit = plot.ds.unit_system.base_units[dimensions.time] 2509 else: 2510 # in the case of non- code units then we 2511 self.time_unit = plot.ds.get_smallest_appropriate_unit( 2512 plot.ds.current_time, quantity="time" 2513 ) 2514 t = plot.ds.current_time.in_units(self.time_unit) 2515 if self.time_offset is not None: 2516 if isinstance(self.time_offset, tuple): 2517 toffset = plot.ds.quan(self.time_offset[0], self.time_offset[1]) 2518 elif isinstance(self.time_offset, Number): 2519 toffset = plot.ds.quan(self.time_offset, self.time_unit) 2520 elif not isinstance(self.time_offset, YTQuantity): 2521 raise RuntimeError( 2522 "'time_offset' must be a float, tuple, or YTQuantity!" 2523 ) 2524 t -= toffset.in_units(self.time_unit) 2525 try: 2526 # here the time unit will be in brackets on the annotation. 2527 un = self.time_unit.latex_representation() 2528 time_unit = r"$\ \ (" + un + r")$" 2529 except AttributeError as err: 2530 if plot.ds.unit_system._code_flag == "code": 2531 raise RuntimeError( 2532 "The time unit str repr didn't match expectations, something is wrong." 2533 ) from err 2534 time_unit = str(self.time_unit).replace("_", " ") 2535 self.text += self.time_format.format(time=float(t), units=time_unit) 2536 2537 # If time and redshift both shown, do one on top of the other 2538 if self.time and self.redshift: 2539 self.text += "\n" 2540 2541 # If we're annotating the redshift, put it in the correct format 2542 if self.redshift: 2543 try: 2544 z = plot.data.ds.current_redshift 2545 except AttributeError: 2546 raise AttributeError( 2547 "Dataset does not have current_redshift. Set redshift=False." 2548 ) 2549 # Replace instances of -0.0* with 0.0* to avoid 2550 # negative null redshifts (e.g., "-0.00"). 2551 self.text += self.redshift_format.format(redshift=float(z)) 2552 self.text = re.sub("-(0.0*)$", r"\g<1>", self.text) 2553 2554 # This is just a fancy wrapper around the TextLabelCallback 2555 tcb = TextLabelCallback( 2556 self.pos, 2557 self.text, 2558 coord_system=self.coord_system, 2559 text_args=self.text_args, 2560 inset_box_args=self.inset_box_args, 2561 ) 2562 return tcb(plot) 2563 2564 2565class ScaleCallback(PlotCallback): 2566 r""" 2567 Annotates the scale of the plot at a specified location in the image 2568 (either in a preset corner, or by specifying (x,y) image coordinates with 2569 the pos argument. Coeff and units (e.g. 1 Mpc or 100 kpc) refer to the 2570 distance scale you desire to show on the plot. If no coeff and units are 2571 specified, an appropriate pair will be determined such that your scale bar 2572 is never smaller than min_frac or greater than max_frac of your plottable 2573 axis length. Additional customization of the scale bar is possible by 2574 adjusting the text_args and size_bar_args dictionaries. The text_args 2575 dictionary accepts matplotlib's font_properties arguments to override 2576 the default font_properties for the current plot. The size_bar_args 2577 dictionary accepts keyword arguments for the AnchoredSizeBar class in 2578 matplotlib's axes_grid toolkit. 2579 2580 Parameters 2581 ---------- 2582 2583 corner : string, optional 2584 Corner sets up one of 4 predeterimined locations for the scale bar 2585 to be displayed in the image: 'upper_left', 'upper_right', 'lower_left', 2586 'lower_right' (also allows None). This value will be overridden by the 2587 optional 'pos' keyword. 2588 2589 coeff : float, optional 2590 The coefficient of the unit defining the distance scale (e.g. 10 kpc or 2591 100 Mpc) for overplotting. If set to None along with unit keyword, 2592 coeff will be automatically determined to be a power of 10 2593 relative to the best-fit unit. 2594 2595 unit : string, optional 2596 unit must be a valid yt distance unit (e.g. 'm', 'km', 'AU', 'pc', 2597 'kpc', etc.) or set to None. If set to None, will be automatically 2598 determined to be the best-fit to the data. 2599 2600 pos : 2- or 3-element tuples, lists, or arrays, optional 2601 The image location of the scale bar in the plot coordinate system. 2602 Setting pos overrides the corner parameter. 2603 2604 min_frac, max_frac: float, optional 2605 The minimum/maximum fraction of the axis width for the scale bar to 2606 extend. A value of 1 would allow the scale bar to extend across the 2607 entire axis width. Only used for automatically calculating 2608 best-fit coeff and unit when neither is specified, otherwise 2609 disregarded. 2610 2611 coord_system : string, optional 2612 This string defines the coordinate system of the coordinates of pos 2613 Valid coordinates are: 2614 2615 - "data": 3D dataset coordinates 2616 - "plot": 2D coordinates defined by the actual plot limits 2617 - "axis": MPL axis coordinates: (0,0) is lower left; (1,1) is upper right 2618 - "figure": MPL figure coordinates: (0,0) is lower left, (1,1) is upper right 2619 2620 text_args : dictionary, optional 2621 A dictionary of parameters to used to update the font_properties 2622 for the text in this callback. For any property not set, it will 2623 use the defaults of the plot. Thus one can modify the text size with 2624 ``text_args={'size':24}`` 2625 2626 size_bar_args : dictionary, optional 2627 A dictionary of parameters to be passed to the Matplotlib 2628 AnchoredSizeBar initializer. 2629 Defaults: ``{'pad': 0.25, 'sep': 5, 'borderpad': 1, 'color': 'w'}`` 2630 2631 draw_inset_box : boolean, optional 2632 Whether or not an inset box should be included around the scale bar. 2633 2634 inset_box_args : dictionary, optional 2635 A dictionary of keyword arguments to be passed to the matplotlib Patch 2636 object that represents the inset box. 2637 Defaults: ``{'facecolor': 'black', 'linewidth': 3, 2638 'edgecolor': 'white', 'alpha': 0.5, 'boxstyle': 'square'}`` 2639 2640 scale_text_format : string, optional 2641 This specifies the format of the scalebar value assuming "scale" is the 2642 numerical value and "unit" is units of the scale (e.g. 'cm', 'kpc', etc.) 2643 The scale can be specified to arbitrary precision according to printf 2644 formatting codes. The format string must only specify "scale" and "units". 2645 Example: "Length = {scale:.2f} {units}". Default: "{scale} {units}" 2646 2647 Example 2648 ------- 2649 2650 >>> import yt 2651 >>> ds = yt.load("Enzo_64/DD0020/data0020") 2652 >>> s = yt.SlicePlot(ds, "z", "density") 2653 >>> s.annotate_scale() 2654 """ 2655 2656 _type_name = "scale" 2657 _supported_geometries = ("cartesian", "spectral_cube", "force") 2658 2659 def __init__( 2660 self, 2661 corner="lower_right", 2662 coeff=None, 2663 unit=None, 2664 pos=None, 2665 max_frac=0.16, 2666 min_frac=0.015, 2667 coord_system="axis", 2668 text_args=None, 2669 size_bar_args=None, 2670 draw_inset_box=False, 2671 inset_box_args=None, 2672 scale_text_format="{scale} {units}", 2673 ): 2674 2675 def_size_bar_args = {"pad": 0.05, "sep": 5, "borderpad": 1, "color": "w"} 2676 2677 def_inset_box_args = { 2678 "facecolor": "black", 2679 "linewidth": 3, 2680 "edgecolor": "white", 2681 "alpha": 0.5, 2682 "boxstyle": "square", 2683 } 2684 2685 # Set position based on corner argument. 2686 self.corner = corner 2687 self.coeff = coeff 2688 self.unit = unit 2689 self.pos = pos 2690 self.max_frac = max_frac 2691 self.min_frac = min_frac 2692 self.coord_system = coord_system 2693 self.scale_text_format = scale_text_format 2694 if size_bar_args is None: 2695 self.size_bar_args = def_size_bar_args 2696 else: 2697 self.size_bar_args = size_bar_args 2698 if inset_box_args is None: 2699 self.inset_box_args = def_inset_box_args 2700 else: 2701 self.inset_box_args = inset_box_args 2702 self.draw_inset_box = draw_inset_box 2703 if text_args is None: 2704 text_args = {} 2705 self.text_args = text_args 2706 2707 def __call__(self, plot): 2708 from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar 2709 2710 # Callback only works for plots with axis ratios of 1 2711 xsize = plot.xlim[1] - plot.xlim[0] 2712 2713 # Setting pos overrides corner argument 2714 if self.pos is None: 2715 if self.corner == "upper_left": 2716 self.pos = (0.11, 0.952) 2717 elif self.corner == "upper_right": 2718 self.pos = (0.89, 0.952) 2719 elif self.corner == "lower_left": 2720 self.pos = (0.11, 0.052) 2721 elif self.corner == "lower_right": 2722 self.pos = (0.89, 0.052) 2723 elif self.corner is None: 2724 self.pos = (0.5, 0.5) 2725 else: 2726 raise ValueError( 2727 "Argument 'corner' must be set to " 2728 "'upper_left', 'upper_right', 'lower_left', " 2729 "'lower_right', or None" 2730 ) 2731 2732 # When identifying a best fit distance unit, do not allow scale marker 2733 # to be greater than max_frac fraction of xaxis or under min_frac 2734 # fraction of xaxis 2735 max_scale = self.max_frac * xsize 2736 min_scale = self.min_frac * xsize 2737 2738 # If no units are set, pick something sensible. 2739 if self.unit is None: 2740 # User has set the axes units and supplied a coefficient. 2741 if plot._axes_unit_names is not None and self.coeff is not None: 2742 self.unit = plot._axes_unit_names[0] 2743 # Nothing provided; identify a best fit distance unit. 2744 else: 2745 min_scale = plot.ds.get_smallest_appropriate_unit( 2746 min_scale, return_quantity=True 2747 ) 2748 max_scale = plot.ds.get_smallest_appropriate_unit( 2749 max_scale, return_quantity=True 2750 ) 2751 if self.coeff is None: 2752 self.coeff = max_scale.v 2753 self.unit = max_scale.units 2754 elif self.coeff is None: 2755 self.coeff = 1 2756 self.scale = plot.ds.quan(self.coeff, self.unit) 2757 text = self.scale_text_format.format(scale=int(self.coeff), units=self.unit) 2758 image_scale = ( 2759 plot.frb.convert_distance_x(self.scale) / plot.frb.convert_distance_x(xsize) 2760 ).v 2761 size_vertical = self.size_bar_args.pop("size_vertical", 0.005 * plot.aspect) 2762 fontproperties = self.size_bar_args.pop( 2763 "fontproperties", plot.font_properties.copy() 2764 ) 2765 frameon = self.size_bar_args.pop("frameon", self.draw_inset_box) 2766 # FontProperties instances use set_<property>() setter functions 2767 for key, val in self.text_args.items(): 2768 setter_func = "set_" + key 2769 try: 2770 getattr(fontproperties, setter_func)(val) 2771 except AttributeError as e: 2772 raise AttributeError( 2773 "Cannot set text_args keyword " 2774 "to include '%s' because MPL's fontproperties object does " 2775 "not contain function '%s'." % (key, setter_func) 2776 ) from e 2777 2778 # this "anchors" the size bar to a box centered on self.pos in axis 2779 # coordinates 2780 self.size_bar_args["bbox_to_anchor"] = self.pos 2781 self.size_bar_args["bbox_transform"] = plot._axes.transAxes 2782 2783 bar = AnchoredSizeBar( 2784 plot._axes.transAxes, 2785 image_scale, 2786 text, 2787 10, 2788 size_vertical=size_vertical, 2789 fontproperties=fontproperties, 2790 frameon=frameon, 2791 **self.size_bar_args, 2792 ) 2793 2794 bar.patch.set(**self.inset_box_args) 2795 2796 plot._axes.add_artist(bar) 2797 2798 return plot 2799 2800 2801class RayCallback(PlotCallback): 2802 """ 2803 Adds a line representing the projected path of a ray across the plot. 2804 The ray can be either a YTOrthoRay, YTRay, or a LightRay object. 2805 annotate_ray() will properly account for periodic rays across the volume. 2806 If arrow is set to True, uses the MPL.pyplot.arrow function, otherwise 2807 uses the MPL.pyplot.plot function to plot a normal line. Adjust 2808 plot_args accordingly. 2809 2810 Parameters 2811 ---------- 2812 2813 ray : YTOrthoRay, YTRay, or LightRay 2814 Ray is the object that we want to include. We overplot the projected 2815 trajectory of the ray. If the object is a trident.LightRay 2816 object, it will only plot the segment of the LightRay that intersects 2817 the dataset currently displayed. 2818 2819 arrow : boolean, optional 2820 Whether or not to place an arrowhead on the front of the ray to denote 2821 direction 2822 Default: False 2823 2824 plot_args : dictionary, optional 2825 A dictionary of any arbitrary parameters to be passed to the Matplotlib 2826 line object. Defaults: {'color':'white', 'linewidth':2}. 2827 2828 Examples 2829 -------- 2830 2831 >>> # Overplot a ray and an ortho_ray object on a projection 2832 >>> import yt 2833 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 2834 >>> oray = ds.ortho_ray(1, (0.3, 0.4)) # orthoray down the y axis 2835 >>> ray = ds.ray((0.1, 0.2, 0.3), (0.6, 0.7, 0.8)) # arbitrary ray 2836 >>> p = yt.ProjectionPlot(ds, "z", "density") 2837 >>> p.annotate_ray(oray) 2838 >>> p.annotate_ray(ray) 2839 >>> p.save() 2840 2841 >>> # Overplot a LightRay object on a projection 2842 >>> import yt 2843 >>> from trident import LightRay 2844 >>> ds = yt.load("enzo_cosmology_plus/RD0004/RD0004") 2845 >>> lr = LightRay( 2846 ... "enzo_cosmology_plus/AMRCosmology.enzo", "Enzo", 0.0, 0.1, time_data=False 2847 ... ) 2848 >>> lray = lr.make_light_ray(seed=1) 2849 >>> p = yt.ProjectionPlot(ds, "z", "density") 2850 >>> p.annotate_ray(lr) 2851 >>> p.save() 2852 2853 """ 2854 2855 _type_name = "ray" 2856 _supported_geometries = ("cartesian", "spectral_cube", "force") 2857 2858 def __init__(self, ray, arrow=False, plot_args=None): 2859 PlotCallback.__init__(self) 2860 def_plot_args = {"color": "white", "linewidth": 2} 2861 self.ray = ray 2862 self.arrow = arrow 2863 if plot_args is None: 2864 plot_args = def_plot_args 2865 self.plot_args = plot_args 2866 2867 def _process_ray(self): 2868 """ 2869 Get the start_coord and end_coord of a ray object 2870 """ 2871 return (self.ray.start_point, self.ray.end_point) 2872 2873 def _process_ortho_ray(self): 2874 """ 2875 Get the start_coord and end_coord of an ortho_ray object 2876 """ 2877 start_coord = self.ray.ds.domain_left_edge.copy() 2878 end_coord = self.ray.ds.domain_right_edge.copy() 2879 2880 xax = self.ray.ds.coordinates.x_axis[self.ray.axis] 2881 yax = self.ray.ds.coordinates.y_axis[self.ray.axis] 2882 start_coord[xax] = end_coord[xax] = self.ray.coords[0] 2883 start_coord[yax] = end_coord[yax] = self.ray.coords[1] 2884 return (start_coord, end_coord) 2885 2886 def _process_light_ray(self, plot): 2887 """ 2888 Get the start_coord and end_coord of a LightRay object. 2889 Identify which of the sections of the LightRay is in the 2890 dataset that is currently being plotted. If there is one, return the 2891 start and end of the corresponding ray segment 2892 """ 2893 2894 for ray_ds in self.ray.light_ray_solution: 2895 if ray_ds["unique_identifier"] == str(plot.ds.unique_identifier): 2896 start_coord = plot.ds.arr(ray_ds["start"]) 2897 end_coord = plot.ds.arr(ray_ds["end"]) 2898 return (start_coord, end_coord) 2899 # if no intersection between the plotted dataset and the LightRay 2900 # return a false tuple to pass to start_coord 2901 return ((False, False), (False, False)) 2902 2903 def __call__(self, plot): 2904 type_name = getattr(self.ray, "_type_name", None) 2905 2906 if type_name == "ray": 2907 start_coord, end_coord = self._process_ray() 2908 2909 elif type_name == "ortho_ray": 2910 start_coord, end_coord = self._process_ortho_ray() 2911 2912 elif hasattr(self.ray, "light_ray_solution"): 2913 start_coord, end_coord = self._process_light_ray(plot) 2914 2915 else: 2916 raise ValueError("ray must be a YTRay, YTOrthoRay, or LightRay object.") 2917 2918 # if start_coord and end_coord are all False, it means no intersecting 2919 # ray segment with this plot. 2920 if not all(start_coord) and not all(end_coord): 2921 return plot 2922 2923 # if possible, break periodic ray into non-periodic 2924 # segments and add each of them individually 2925 if any(plot.ds.periodicity): 2926 segments = periodic_ray( 2927 start_coord.to("code_length"), 2928 end_coord.to("code_length"), 2929 left=plot.ds.domain_left_edge.to("code_length"), 2930 right=plot.ds.domain_right_edge.to("code_length"), 2931 ) 2932 else: 2933 segments = [[start_coord, end_coord]] 2934 2935 # To assure that the last ray segment has an arrow if so desired 2936 # and all other ray segments are lines 2937 for segment in segments[:-1]: 2938 cb = LinePlotCallback( 2939 segment[0], segment[1], coord_system="data", plot_args=self.plot_args 2940 ) 2941 cb(plot) 2942 segment = segments[-1] 2943 if self.arrow: 2944 cb = ArrowCallback( 2945 segment[1], 2946 starting_pos=segment[0], 2947 coord_system="data", 2948 plot_args=self.plot_args, 2949 ) 2950 else: 2951 cb = LinePlotCallback( 2952 segment[0], segment[1], coord_system="data", plot_args=self.plot_args 2953 ) 2954 cb(plot) 2955 return plot 2956 2957 2958class LineIntegralConvolutionCallback(PlotCallback): 2959 """ 2960 Add the line integral convolution to the plot for vector fields 2961 visualization. Two component of vector fields needed to be provided 2962 (i.e., velocity_x and velocity_y, magnetic_field_x and magnetic_field_y). 2963 2964 Parameters 2965 ---------- 2966 2967 field_x, field_y : string 2968 The names of two components of vector field which will be visualized 2969 2970 texture : 2-d array with the same shape of image, optional 2971 Texture will be convolved when computing line integral convolution. 2972 A white noise background will be used as default. 2973 2974 kernellen : float, optional 2975 The lens of kernel for convolution, which is the length over which the 2976 convolution will be performed. For longer kernellen, longer streamline 2977 structure will appear. 2978 2979 lim : 2-element tuple, list, or array, optional 2980 The value of line integral convolution will be clipped to the range 2981 of lim, which applies upper and lower bounds to the values of line 2982 integral convolution and enhance the visibility of plots. Each element 2983 should be in the range of [0,1]. 2984 2985 cmap : string, optional 2986 The name of colormap for line integral convolution plot. 2987 2988 alpha : float, optional 2989 The alpha value for line integral convolution plot. 2990 2991 const_alpha : boolean, optional 2992 If set to False (by default), alpha will be weighted spatially by 2993 the values of line integral convolution; otherwise a constant value 2994 of the given alpha is used. 2995 2996 Example 2997 ------- 2998 2999 >>> import yt 3000 >>> ds = yt.load("Enzo_64/DD0020/data0020") 3001 >>> s = yt.SlicePlot(ds, "z", "density") 3002 >>> s.annotate_line_integral_convolution( 3003 ... "velocity_x", "velocity_y", lim=(0.5, 0.65) 3004 ... ) 3005 """ 3006 3007 _type_name = "line_integral_convolution" 3008 _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") 3009 3010 def __init__( 3011 self, 3012 field_x, 3013 field_y, 3014 texture=None, 3015 kernellen=50.0, 3016 lim=(0.5, 0.6), 3017 cmap="binary", 3018 alpha=0.8, 3019 const_alpha=False, 3020 ): 3021 PlotCallback.__init__(self) 3022 self.field_x = field_x 3023 self.field_y = field_y 3024 self.texture = texture 3025 self.kernellen = kernellen 3026 self.lim = lim 3027 self.cmap = cmap 3028 self.alpha = alpha 3029 self.const_alpha = const_alpha 3030 3031 def __call__(self, plot): 3032 from matplotlib import cm 3033 3034 bounds = self._physical_bounds(plot) 3035 extent = self._plot_bounds(plot) 3036 3037 # We are feeding this size into the pixelizer, where it will properly 3038 # set it in reverse order 3039 nx = plot.image._A.shape[1] 3040 ny = plot.image._A.shape[0] 3041 pixX = plot.data.ds.coordinates.pixelize( 3042 plot.data.axis, plot.data, self.field_x, bounds, (nx, ny) 3043 ) 3044 pixY = plot.data.ds.coordinates.pixelize( 3045 plot.data.axis, plot.data, self.field_y, bounds, (nx, ny) 3046 ) 3047 3048 vectors = np.concatenate((pixX[..., np.newaxis], pixY[..., np.newaxis]), axis=2) 3049 3050 if self.texture is None: 3051 self.texture = np.random.rand(nx, ny).astype(np.double) 3052 elif self.texture.shape != (nx, ny): 3053 raise ValueError( 3054 "'texture' must have the same shape " 3055 "with that of output image (%d, %d)" % (nx, ny) 3056 ) 3057 3058 kernel = np.sin(np.arange(self.kernellen) * np.pi / self.kernellen) 3059 kernel = kernel.astype(np.double) 3060 3061 lic_data = line_integral_convolution_2d(vectors, self.texture, kernel) 3062 lic_data = lic_data / lic_data.max() 3063 lic_data_clip = np.clip(lic_data, self.lim[0], self.lim[1]) 3064 3065 if self.const_alpha: 3066 plot._axes.imshow( 3067 lic_data_clip, 3068 extent=extent, 3069 cmap=self.cmap, 3070 alpha=self.alpha, 3071 origin="lower", 3072 aspect="auto", 3073 ) 3074 else: 3075 lic_data_rgba = cm.ScalarMappable(norm=None, cmap=self.cmap).to_rgba( 3076 lic_data_clip 3077 ) 3078 lic_data_clip_rescale = (lic_data_clip - self.lim[0]) / ( 3079 self.lim[1] - self.lim[0] 3080 ) 3081 lic_data_rgba[..., 3] = lic_data_clip_rescale * self.alpha 3082 plot._axes.imshow( 3083 lic_data_rgba, 3084 extent=extent, 3085 cmap=self.cmap, 3086 origin="lower", 3087 aspect="auto", 3088 ) 3089 3090 return plot 3091 3092 3093class CellEdgesCallback(PlotCallback): 3094 """ 3095 Annotate cell edges. This is done through a second call to pixelize, where 3096 the distance from a pixel to a cell boundary in pixels is compared against 3097 the `line_width` argument. The secondary image is colored as `color` and 3098 overlaid with the `alpha` value. 3099 3100 Parameters 3101 ---------- 3102 line_width : float 3103 The width of the cell edge lines in normalized units relative to the 3104 size of the longest axis. Default is 1% of the size of the smallest 3105 axis. 3106 alpha : float 3107 When the second image is overlaid, it will have this level of alpha 3108 transparency. Default is 1.0 (fully-opaque). 3109 color : tuple of three floats or matplotlib color name 3110 This is the color of the cell edge values. It defaults to black. 3111 3112 Examples 3113 -------- 3114 3115 >>> import yt 3116 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 3117 >>> s = yt.SlicePlot(ds, "z", "density") 3118 >>> s.annotate_cell_edges() 3119 >>> s.save() 3120 """ 3121 3122 _type_name = "cell_edges" 3123 _supported_geometries = ("cartesian", "spectral_cube", "cylindrical") 3124 3125 def __init__(self, line_width=0.002, alpha=1.0, color="black"): 3126 from matplotlib.colors import ColorConverter 3127 3128 conv = ColorConverter() 3129 PlotCallback.__init__(self) 3130 self.line_width = line_width 3131 self.alpha = alpha 3132 self.color = (np.array(conv.to_rgb(color)) * 255).astype("uint8") 3133 3134 def __call__(self, plot): 3135 if plot.data.ds.geometry == "cylindrical" and plot.data.ds.dimensionality == 3: 3136 raise NotImplementedError( 3137 "Cell edge annotation is only supported for \ 3138 for 2D cylindrical geometry, not 3D" 3139 ) 3140 x0, x1, y0, y1 = self._physical_bounds(plot) 3141 xx0, xx1, yy0, yy1 = self._plot_bounds(plot) 3142 nx = plot.image._A.shape[1] 3143 ny = plot.image._A.shape[0] 3144 aspect = float((y1 - y0) / (x1 - x0)) 3145 pixel_aspect = float(ny) / nx 3146 relative_aspect = pixel_aspect / aspect 3147 if relative_aspect > 1: 3148 nx = int(nx / relative_aspect) 3149 else: 3150 ny = int(ny * relative_aspect) 3151 if aspect > 1: 3152 if nx < 1600: 3153 nx = int(1600.0 / nx * ny) 3154 ny = 1600 3155 long_axis = ny 3156 else: 3157 if ny < 1600: 3158 nx = int(1600.0 / ny * nx) 3159 ny = 1600 3160 long_axis = nx 3161 line_width = max(self.line_width * long_axis, 1.0) 3162 im = np.zeros((ny, nx), dtype="f8") 3163 pixelize_cartesian( 3164 im, 3165 plot.data["px"], 3166 plot.data["py"], 3167 plot.data["pdx"], 3168 plot.data["pdy"], 3169 plot.data["px"], # dummy field 3170 (x0, x1, y0, y1), 3171 line_width=line_width, 3172 ) 3173 # New image: 3174 im_buffer = np.zeros((ny, nx, 4), dtype="uint8") 3175 im_buffer[im > 0, 3] = 255 3176 im_buffer[im > 0, :3] = self.color 3177 plot._axes.imshow( 3178 im_buffer, 3179 origin="lower", 3180 interpolation="bilinear", 3181 extent=[xx0, xx1, yy0, yy1], 3182 alpha=self.alpha, 3183 ) 3184 plot._axes.set_xlim(xx0, xx1) 3185 plot._axes.set_ylim(yy0, yy1) 3186