1import base64 2import builtins 3import os 4from collections import OrderedDict 5from functools import wraps 6 7import matplotlib 8import numpy as np 9from more_itertools.more import always_iterable, unzip 10from packaging.version import parse as parse_version 11 12from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys 13from yt.data_objects.static_output import Dataset 14from yt.frontends.ytdata.data_structures import YTProfileDataset 15from yt.funcs import is_sequence, iter_fields, matplotlib_style_context 16from yt.utilities.exceptions import YTNotInsideNotebook 17from yt.utilities.logger import ytLogger as mylog 18 19from ..data_objects.selection_objects.data_selection_objects import YTSelectionContainer 20from ._commons import validate_image_name 21from .base_plot_types import ImagePlotMPL, PlotMPL 22from .plot_container import ( 23 ImagePlotContainer, 24 get_log_minorticks, 25 invalidate_plot, 26 linear_transform, 27 log_transform, 28 validate_plot, 29) 30 31MPL_VERSION = parse_version(matplotlib.__version__) 32 33 34def invalidate_profile(f): 35 @wraps(f) 36 def newfunc(*args, **kwargs): 37 rv = f(*args, **kwargs) 38 args[0]._profile_valid = False 39 return rv 40 41 return newfunc 42 43 44class PlotContainerDict(OrderedDict): 45 def __missing__(self, key): 46 plot = PlotMPL((10, 8), [0.1, 0.1, 0.8, 0.8], None, None) 47 self[key] = plot 48 return self[key] 49 50 51class FigureContainer(OrderedDict): 52 def __init__(self, plots): 53 self.plots = plots 54 super().__init__() 55 56 def __missing__(self, key): 57 self[key] = self.plots[key].figure 58 return self[key] 59 60 def __iter__(self): 61 return iter(self.plots) 62 63 64class AxesContainer(OrderedDict): 65 def __init__(self, plots): 66 self.plots = plots 67 self.ylim = {} 68 self.xlim = (None, None) 69 super().__init__() 70 71 def __missing__(self, key): 72 self[key] = self.plots[key].axes 73 return self[key] 74 75 def __setitem__(self, key, value): 76 super().__setitem__(key, value) 77 self.ylim[key] = (None, None) 78 79 80def sanitize_label(labels, nprofiles): 81 labels = list(always_iterable(labels)) or [None] 82 83 if len(labels) == 1: 84 labels = labels * nprofiles 85 86 if len(labels) != nprofiles: 87 raise ValueError( 88 f"Number of labels {len(labels)} must match number of profiles {nprofiles}" 89 ) 90 91 invalid_data = [ 92 (label, type(label)) 93 for label in labels 94 if label is not None and not isinstance(label, str) 95 ] 96 if invalid_data: 97 invalid_labels, types = unzip(invalid_data) 98 raise TypeError( 99 "All labels must be None or a string, " 100 f"received {invalid_labels} with type {types}" 101 ) 102 103 return labels 104 105 106def data_object_or_all_data(data_source): 107 if isinstance(data_source, Dataset): 108 data_source = data_source.all_data() 109 110 if not isinstance(data_source, YTSelectionContainer): 111 raise RuntimeError("data_source must be a yt selection data object") 112 113 return data_source 114 115 116class ProfilePlot: 117 r""" 118 Create a 1d profile plot from a data source or from a list 119 of profile objects. 120 121 Given a data object (all_data, region, sphere, etc.), an x field, 122 and a y field (or fields), this will create a one-dimensional profile 123 of the average (or total) value of the y field in bins of the x field. 124 125 This can be used to create profiles from given fields or to plot 126 multiple profiles created from 127 `yt.data_objects.profiles.create_profile`. 128 129 Parameters 130 ---------- 131 data_source : YTSelectionContainer Object 132 The data object to be profiled, such as all_data, region, or 133 sphere. If a dataset is passed in instead, an all_data data object 134 is generated internally from the dataset. 135 x_field : str 136 The binning field for the profile. 137 y_fields : str or list 138 The field or fields to be profiled. 139 weight_field : str 140 The weight field for calculating weighted averages. If None, 141 the profile values are the sum of the field values within the bin. 142 Otherwise, the values are a weighted average. 143 Default : ("gas", "mass") 144 n_bins : int 145 The number of bins in the profile. 146 Default: 64. 147 accumulation : bool 148 If True, the profile values for a bin N are the cumulative sum of 149 all the values from bin 0 to N. 150 Default: False. 151 fractional : If True the profile values are divided by the sum of all 152 the profile data such that the profile represents a probability 153 distribution function. 154 label : str or list of strings 155 If a string, the label to be put on the line plotted. If a list, 156 this should be a list of labels for each profile to be overplotted. 157 Default: None. 158 plot_spec : dict or list of dicts 159 A dictionary or list of dictionaries containing plot keyword 160 arguments. For example, dict(color="red", linestyle=":"). 161 Default: None. 162 x_log : bool 163 Whether the x_axis should be plotted with a logarithmic 164 scaling (True), or linear scaling (False). 165 Default: True. 166 y_log : dict or bool 167 A dictionary containing field:boolean pairs, setting the logarithmic 168 property for that field. May be overridden after instantiation using 169 set_log 170 A single boolean can be passed to signify all fields should use 171 logarithmic (True) or linear scaling (False). 172 Default: True. 173 174 Examples 175 -------- 176 177 This creates profiles of a single dataset. 178 179 >>> import yt 180 >>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046") 181 >>> ad = ds.all_data() 182 >>> plot = yt.ProfilePlot( 183 ... ad, 184 ... ("gas", "density"), 185 ... [("gas", "temperature"), ("gas", "velocity_x")], 186 ... weight_field=("gas", "mass"), 187 ... plot_spec=dict(color="red", linestyle="--"), 188 ... ) 189 >>> plot.save() 190 191 This creates profiles from a time series object. 192 193 >>> es = yt.load_simulation("AMRCosmology.enzo", "Enzo") 194 >>> es.get_time_series() 195 196 >>> profiles = [] 197 >>> labels = [] 198 >>> plot_specs = [] 199 >>> for ds in es[-4:]: 200 ... ad = ds.all_data() 201 ... profiles.append( 202 ... create_profile( 203 ... ad, 204 ... [("gas", "density")], 205 ... fields=[("gas", "temperature"), ("gas", "velocity_x")], 206 ... ) 207 ... ) 208 ... labels.append(ds.current_redshift) 209 ... plot_specs.append(dict(linestyle="--", alpha=0.7)) 210 211 >>> plot = yt.ProfilePlot.from_profiles( 212 ... profiles, labels=labels, plot_specs=plot_specs 213 ... ) 214 >>> plot.save() 215 216 Use set_line_property to change line properties of one or all profiles. 217 218 """ 219 220 x_log = None 221 y_log = None 222 x_title = None 223 y_title = None 224 _plot_valid = False 225 226 def __init__( 227 self, 228 data_source, 229 x_field, 230 y_fields, 231 weight_field=("gas", "mass"), 232 n_bins=64, 233 accumulation=False, 234 fractional=False, 235 label=None, 236 plot_spec=None, 237 x_log=True, 238 y_log=True, 239 ): 240 241 data_source = data_object_or_all_data(data_source) 242 y_fields = list(iter_fields(y_fields)) 243 logs = {x_field: bool(x_log)} 244 if isinstance(y_log, bool): 245 y_log = {y_field: y_log for y_field in y_fields} 246 247 if isinstance(data_source.ds, YTProfileDataset): 248 profiles = [data_source.ds.profile] 249 else: 250 profiles = [ 251 create_profile( 252 data_source, 253 [x_field], 254 n_bins=[n_bins], 255 fields=y_fields, 256 weight_field=weight_field, 257 accumulation=accumulation, 258 fractional=fractional, 259 logs=logs, 260 ) 261 ] 262 263 if plot_spec is None: 264 plot_spec = [dict() for p in profiles] 265 if not isinstance(plot_spec, list): 266 plot_spec = [plot_spec.copy() for p in profiles] 267 268 ProfilePlot._initialize_instance(self, profiles, label, plot_spec, y_log) 269 270 @validate_plot 271 def save(self, name=None, suffix=".png", mpl_kwargs=None): 272 r""" 273 Saves a 1d profile plot. 274 275 Parameters 276 ---------- 277 name : str 278 The output file keyword. 279 suffix : string 280 Specify the image type by its suffix. If not specified, the output 281 type will be inferred from the filename. Defaults to PNG. 282 mpl_kwargs : dict 283 A dict of keyword arguments to be passed to matplotlib. 284 """ 285 if not self._plot_valid: 286 self._setup_plots() 287 unique = set(self.plots.values()) 288 if len(unique) < len(self.plots): 289 iters = zip(range(len(unique)), sorted(unique)) 290 else: 291 iters = self.plots.items() 292 293 if name is None: 294 if len(self.profiles) == 1: 295 name = str(self.profiles[0].ds) 296 else: 297 name = "Multi-data" 298 299 name = validate_image_name(name, suffix) 300 prefix, suffix = os.path.splitext(name) 301 302 xfn = self.profiles[0].x_field 303 if isinstance(xfn, tuple): 304 xfn = xfn[1] 305 306 names = [] 307 for uid, plot in iters: 308 if isinstance(uid, tuple): 309 uid = uid[1] 310 uid_name = f"{prefix}_1d-Profile_{xfn}_{uid}{suffix}" 311 names.append(uid_name) 312 mylog.info("Saving %s", uid_name) 313 with matplotlib_style_context(): 314 plot.save(uid_name, mpl_kwargs=mpl_kwargs) 315 return names 316 317 @validate_plot 318 def show(self): 319 r"""This will send any existing plots to the IPython notebook. 320 321 If yt is being run from within an IPython session, and it is able to 322 determine this, this function will send any existing plots to the 323 notebook for display. 324 325 If yt can't determine if it's inside an IPython session, it will raise 326 YTNotInsideNotebook. 327 328 Examples 329 -------- 330 331 >>> import yt 332 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 333 >>> pp = ProfilePlot(ds.all_data(), ("gas", "density"), ("gas", "temperature")) 334 >>> pp.show() 335 336 """ 337 if "__IPYTHON__" in dir(builtins): 338 from IPython.display import display 339 340 display(self) 341 else: 342 raise YTNotInsideNotebook 343 344 @validate_plot 345 def _repr_html_(self): 346 """Return an html representation of the plot object. Will display as a 347 png for each WindowPlotMPL instance in self.plots""" 348 ret = "" 349 unique = set(self.plots.values()) 350 if len(unique) < len(self.plots): 351 iters = sorted(unique) 352 else: 353 iters = self.plots.values() 354 for plot in iters: 355 with matplotlib_style_context(): 356 img = plot._repr_png_() 357 img = base64.b64encode(img).decode() 358 ret += ( 359 r'<img style="max-width:100%;max-height:100%;" ' 360 r'src="data:image/png;base64,{}"><br>'.format(img) 361 ) 362 return ret 363 364 def _setup_plots(self): 365 if self._plot_valid: 366 return 367 for f in self.axes: 368 self.axes[f].cla() 369 if f in self._plot_text: 370 self.plots[f].axes.text( 371 self._text_xpos[f], 372 self._text_ypos[f], 373 self._plot_text[f], 374 fontproperties=self._font_properties, 375 **self._text_kwargs[f], 376 ) 377 378 for i, profile in enumerate(self.profiles): 379 for field, field_data in profile.items(): 380 self.axes[field].plot( 381 np.array(profile.x), 382 np.array(field_data), 383 label=self.label[i], 384 **self.plot_spec[i], 385 ) 386 387 for profile in self.profiles: 388 for fname in profile.keys(): 389 axes = self.axes[fname] 390 xscale, yscale = self._get_field_log(fname, profile) 391 xtitle, ytitle = self._get_field_title(fname, profile) 392 393 axes.set_xscale(xscale) 394 axes.set_yscale(yscale) 395 396 axes.set_ylabel(ytitle) 397 axes.set_xlabel(xtitle) 398 399 axes.set_ylim(*self.axes.ylim[fname]) 400 axes.set_xlim(*self.axes.xlim) 401 402 if fname in self._plot_title: 403 axes.set_title(self._plot_title[fname]) 404 405 if any(self.label): 406 axes.legend(loc="best") 407 self._set_font_properties() 408 self._plot_valid = True 409 410 @classmethod 411 def _initialize_instance(cls, obj, profiles, labels, plot_specs, y_log): 412 obj._plot_title = {} 413 obj._plot_text = {} 414 obj._text_xpos = {} 415 obj._text_ypos = {} 416 obj._text_kwargs = {} 417 418 from matplotlib.font_manager import FontProperties 419 420 obj._font_properties = FontProperties(family="stixgeneral", size=18) 421 obj._font_color = None 422 obj.profiles = list(always_iterable(profiles)) 423 obj.x_log = None 424 obj.y_log = sanitize_field_tuple_keys(y_log, obj.profiles[0].data_source) or {} 425 obj.y_title = {} 426 obj.x_title = None 427 obj.label = sanitize_label(labels, len(obj.profiles)) 428 if plot_specs is None: 429 plot_specs = [dict() for p in obj.profiles] 430 obj.plot_spec = plot_specs 431 obj.plots = PlotContainerDict() 432 obj.figures = FigureContainer(obj.plots) 433 obj.axes = AxesContainer(obj.plots) 434 obj._setup_plots() 435 return obj 436 437 @classmethod 438 def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None): 439 r""" 440 Instantiate a ProfilePlot object from a list of profiles 441 created with :func:`~yt.data_objects.profiles.create_profile`. 442 443 Parameters 444 ---------- 445 profiles : a profile or list of profiles 446 A single profile or list of profile objects created with 447 :func:`~yt.data_objects.profiles.create_profile`. 448 labels : list of strings 449 A list of labels for each profile to be overplotted. 450 Default: None. 451 plot_specs : list of dicts 452 A list of dictionaries containing plot keyword 453 arguments. For example, [dict(color="red", linestyle=":")]. 454 Default: None. 455 456 Examples 457 -------- 458 459 >>> from yt import load_simulation 460 >>> es = load_simulation("AMRCosmology.enzo", "Enzo") 461 >>> es.get_time_series() 462 463 >>> profiles = [] 464 >>> labels = [] 465 >>> plot_specs = [] 466 >>> for ds in es[-4:]: 467 ... ad = ds.all_data() 468 ... profiles.append( 469 ... create_profile( 470 ... ad, 471 ... [("gas", "density")], 472 ... fields=[("gas", "temperature"), ("gas", "velocity_x")], 473 ... ) 474 ... ) 475 ... labels.append(ds.current_redshift) 476 ... plot_specs.append(dict(linestyle="--", alpha=0.7)) 477 >>> plot = ProfilePlot.from_profiles( 478 ... profiles, labels=labels, plot_specs=plot_specs 479 ... ) 480 >>> plot.save() 481 482 """ 483 if labels is not None and len(profiles) != len(labels): 484 raise RuntimeError("Profiles list and labels list must be the same size.") 485 if plot_specs is not None and len(plot_specs) != len(profiles): 486 raise RuntimeError( 487 "Profiles list and plot_specs list must be the same size." 488 ) 489 obj = cls.__new__(cls) 490 return cls._initialize_instance(obj, profiles, labels, plot_specs, y_log) 491 492 @invalidate_plot 493 def set_line_property(self, property, value, index=None): 494 r""" 495 Set properties for one or all lines to be plotted. 496 497 Parameters 498 ---------- 499 property : str 500 The line property to be set. 501 value : str, int, float 502 The value to set for the line property. 503 index : int 504 The index of the profile in the list of profiles to be 505 changed. If None, change all plotted lines. 506 Default : None. 507 508 Examples 509 -------- 510 511 Change all the lines in a plot 512 plot.set_line_property("linestyle", "-") 513 514 Change a single line. 515 plot.set_line_property("linewidth", 4, index=0) 516 517 """ 518 if index is None: 519 specs = self.plot_spec 520 else: 521 specs = [self.plot_spec[index]] 522 for spec in specs: 523 spec[property] = value 524 return self 525 526 @invalidate_plot 527 def set_log(self, field, log): 528 """set a field to log or linear. 529 530 Parameters 531 ---------- 532 field : string 533 the field to set a transform 534 log : boolean 535 Log on/off. 536 """ 537 if field == "all": 538 self.x_log = log 539 for field in list(self.profiles[0].field_data.keys()): 540 self.y_log[field] = log 541 else: 542 (field,) = self.profiles[0].data_source._determine_fields([field]) 543 if field == self.profiles[0].x_field: 544 self.x_log = log 545 elif field in self.profiles[0].field_data: 546 self.y_log[field] = log 547 else: 548 raise KeyError(f"Field {field} not in profile plot!") 549 return self 550 551 @invalidate_plot 552 def set_ylabel(self, field, label): 553 """Sets a new ylabel for the specified fields 554 555 Parameters 556 ---------- 557 field : string 558 The name of the field that is to be changed. 559 560 label : string 561 The label to be placed on the y-axis 562 """ 563 if field == "all": 564 for field in self.profiles[0].field_data: 565 self.y_title[field] = label 566 else: 567 (field,) = self.profiles[0].data_source._determine_fields([field]) 568 if field in self.profiles[0].field_data: 569 self.y_title[field] = label 570 else: 571 raise KeyError(f"Field {field} not in profile plot!") 572 573 return self 574 575 @invalidate_plot 576 def set_xlabel(self, label): 577 """Sets a new xlabel for all profiles 578 579 Parameters 580 ---------- 581 label : string 582 The label to be placed on the x-axis 583 """ 584 self.x_title = label 585 586 return self 587 588 @invalidate_plot 589 def set_unit(self, field, unit): 590 """Sets a new unit for the requested field 591 592 Parameters 593 ---------- 594 field : string 595 The name of the field that is to be changed. 596 597 unit : string or Unit object 598 The name of the new unit. 599 """ 600 fd = self.profiles[0].data_source._determine_fields(field)[0] 601 for profile in self.profiles: 602 if fd == profile.x_field: 603 profile.set_x_unit(unit) 604 elif fd[1] in self.profiles[0].field_map: 605 profile.set_field_unit(field, unit) 606 else: 607 raise KeyError(f"Field {field} not in profile plot!") 608 return self 609 610 @invalidate_plot 611 def set_xlim(self, xmin=None, xmax=None): 612 """Sets the limits of the bin field 613 614 Parameters 615 ---------- 616 617 xmin : float or None 618 The new x minimum. Defaults to None, which leaves the xmin 619 unchanged. 620 621 xmax : float or None 622 The new x maximum. Defaults to None, which leaves the xmax 623 unchanged. 624 625 Examples 626 -------- 627 628 >>> import yt 629 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 630 >>> pp = yt.ProfilePlot( 631 ... ds.all_data(), ("gas", "density"), ("gas", "temperature") 632 ... ) 633 >>> pp.set_xlim(1e-29, 1e-24) 634 >>> pp.save() 635 636 """ 637 self.axes.xlim = (xmin, xmax) 638 for i, p in enumerate(self.profiles): 639 if xmin is None: 640 xmi = p.x_bins.min() 641 else: 642 xmi = xmin 643 if xmax is None: 644 xma = p.x_bins.max() 645 else: 646 xma = xmax 647 extrema = {p.x_field: ((xmi, str(p.x.units)), (xma, str(p.x.units)))} 648 units = {p.x_field: str(p.x.units)} 649 if self.x_log is None: 650 logs = None 651 else: 652 logs = {p.x_field: self.x_log} 653 for field in p.field_map.values(): 654 units[field] = str(p.field_data[field].units) 655 self.profiles[i] = create_profile( 656 p.data_source, 657 p.x_field, 658 n_bins=len(p.x_bins) - 1, 659 fields=list(p.field_map.values()), 660 weight_field=p.weight_field, 661 accumulation=p.accumulation, 662 fractional=p.fractional, 663 logs=logs, 664 extrema=extrema, 665 units=units, 666 ) 667 return self 668 669 @invalidate_plot 670 def set_ylim(self, field, ymin=None, ymax=None): 671 """Sets the plot limits for the specified field we are binning. 672 673 Parameters 674 ---------- 675 676 field : string or field tuple 677 678 The field that we want to adjust the plot limits for. 679 680 ymin : float or None 681 The new y minimum. Defaults to None, which leaves the ymin 682 unchanged. 683 684 ymax : float or None 685 The new y maximum. Defaults to None, which leaves the ymax 686 unchanged. 687 688 Examples 689 -------- 690 691 >>> import yt 692 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 693 >>> pp = yt.ProfilePlot( 694 ... ds.all_data(), 695 ... ("gas", "density"), 696 ... [("gas", "temperature"), ("gas", "velocity_x")], 697 ... ) 698 >>> pp.set_ylim(("gas", "temperature"), 1e4, 1e6) 699 >>> pp.save() 700 701 """ 702 fields = list(self.axes.keys()) if field == "all" else field 703 for profile in self.profiles: 704 for field in profile.data_source._determine_fields(fields): 705 if field in profile.field_map: 706 field = profile.field_map[field] 707 self.axes.ylim[field] = (ymin, ymax) 708 # Continue on to the next profile. 709 break 710 return self 711 712 def _set_font_properties(self): 713 for f in self.plots: 714 self.plots[f]._set_font_properties(self._font_properties, self._font_color) 715 716 def _get_field_log(self, field_y, profile): 717 yfi = profile.field_info[field_y] 718 if self.x_log is None: 719 x_log = profile.x_log 720 else: 721 x_log = self.x_log 722 y_log = self.y_log.get(field_y, yfi.take_log) 723 scales = {True: "log", False: "linear"} 724 return scales[x_log], scales[y_log] 725 726 def _get_field_label(self, field, field_info, field_unit, fractional=False): 727 field_unit = field_unit.latex_representation() 728 field_name = field_info.display_name 729 if isinstance(field, tuple): 730 field = field[1] 731 if field_name is None: 732 field_name = r"$\rm{" + field + r"}$" 733 field_name = r"$\rm{" + field.replace("_", r"\ ").title() + r"}$" 734 elif field_name.find("$") == -1: 735 field_name = field_name.replace(" ", r"\ ") 736 field_name = r"$\rm{" + field_name + r"}$" 737 if fractional: 738 label = field_name + r"$\rm{\ Probability\ Density}$" 739 elif field_unit is None or field_unit == "": 740 label = field_name 741 else: 742 label = field_name + r"$\ \ (" + field_unit + r")$" 743 return label 744 745 def _get_field_title(self, field_y, profile): 746 field_x = profile.x_field 747 xfi = profile.field_info[field_x] 748 yfi = profile.field_info[field_y] 749 x_unit = profile.x.units 750 y_unit = profile.field_units[field_y] 751 fractional = profile.fractional 752 x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit) 753 y_title = self.y_title.get(field_y, None) or self._get_field_label( 754 field_y, yfi, y_unit, fractional 755 ) 756 757 return (x_title, y_title) 758 759 @invalidate_plot 760 def annotate_title(self, title, field="all"): 761 r"""Set a title for the plot. 762 763 Parameters 764 ---------- 765 title : str 766 The title to add. 767 field : str or list of str 768 The field name for which title needs to be set. 769 770 Examples 771 -------- 772 >>> # To set title for all the fields: 773 >>> plot.annotate_title("This is a Profile Plot") 774 775 >>> # To set title for specific fields: 776 >>> plot.annotate_title("Profile Plot for Temperature", ("gas", "temperature")) 777 778 >>> # Setting same plot title for both the given fields 779 >>> plot.annotate_title( 780 ... "Profile Plot: Temperature-Dark Matter Density", 781 ... [("gas", "temperature"), ("deposit", "dark_matter_density")], 782 ... ) 783 784 """ 785 fields = list(self.axes.keys()) if field == "all" else field 786 for profile in self.profiles: 787 for field in profile.data_source._determine_fields(fields): 788 if field in profile.field_map: 789 field = profile.field_map[field] 790 self._plot_title[field] = title 791 return self 792 793 @invalidate_plot 794 def annotate_text(self, xpos=0.0, ypos=0.0, text=None, field="all", **text_kwargs): 795 r"""Allow the user to insert text onto the plot 796 797 The x-position and y-position must be given as well as the text string. 798 Add *text* to plot at location *xpos*, *ypos* in plot coordinates for 799 the given fields or by default for all fields. 800 (see example below). 801 802 Parameters 803 ---------- 804 xpos : float 805 Position on plot in x-coordinates. 806 ypos : float 807 Position on plot in y-coordinates. 808 text : str 809 The text to insert onto the plot. 810 field : str or tuple 811 The name of the field to add text to. 812 **text_kwargs : dict 813 Extra keyword arguments will be passed to matplotlib text instance 814 815 >>> import yt 816 >>> from yt.units import kpc 817 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 818 >>> my_galaxy = ds.disk(ds.domain_center, [0.0, 0.0, 1.0], 10 * kpc, 3 * kpc) 819 >>> plot = yt.ProfilePlot( 820 ... my_galaxy, ("gas", "density"), [("gas", "temperature")] 821 ... ) 822 823 >>> # Annotate text for all the fields 824 >>> plot.annotate_text(1e-26, 1e5, "This is annotated text in the plot area.") 825 >>> plot.save() 826 827 >>> # Annotate text for a given field 828 >>> plot.annotate_text(1e-26, 1e5, "Annotated text", ("gas", "temperature")) 829 >>> plot.save() 830 831 >>> # Annotate text for multiple fields 832 >>> fields = [("gas", "temperature"), ("gas", "density")] 833 >>> plot.annotate_text(1e-26, 1e5, "Annotated text", fields) 834 >>> plot.save() 835 836 """ 837 fields = list(self.axes.keys()) if field == "all" else field 838 for profile in self.profiles: 839 for field in profile.data_source._determine_fields(fields): 840 if field in profile.field_map: 841 field = profile.field_map[field] 842 self._plot_text[field] = text 843 self._text_xpos[field] = xpos 844 self._text_ypos[field] = ypos 845 self._text_kwargs[field] = text_kwargs 846 return self 847 848 849class PhasePlot(ImagePlotContainer): 850 r""" 851 Create a 2d profile (phase) plot from a data source or from 852 profile object created with 853 `yt.data_objects.profiles.create_profile`. 854 855 Given a data object (all_data, region, sphere, etc.), an x field, 856 y field, and z field (or fields), this will create a two-dimensional 857 profile of the average (or total) value of the z field in bins of the 858 x and y fields. 859 860 Parameters 861 ---------- 862 data_source : YTSelectionContainer Object 863 The data object to be profiled, such as all_data, region, or 864 sphere. If a dataset is passed in instead, an all_data data object 865 is generated internally from the dataset. 866 x_field : str 867 The x binning field for the profile. 868 y_field : str 869 The y binning field for the profile. 870 z_fields : str or list 871 The field or fields to be profiled. 872 weight_field : str 873 The weight field for calculating weighted averages. If None, 874 the profile values are the sum of the field values within the bin. 875 Otherwise, the values are a weighted average. 876 Default : ("gas", "mass") 877 x_bins : int 878 The number of bins in x field for the profile. 879 Default: 128. 880 y_bins : int 881 The number of bins in y field for the profile. 882 Default: 128. 883 accumulation : bool or list of bools 884 If True, the profile values for a bin n are the cumulative sum of 885 all the values from bin 0 to n. If -True, the sum is reversed so 886 that the value for bin n is the cumulative sum from bin N (total bins) 887 to n. A list of values can be given to control the summation in each 888 dimension independently. 889 Default: False. 890 fractional : If True the profile values are divided by the sum of all 891 the profile data such that the profile represents a probability 892 distribution function. 893 fontsize : int 894 Font size for all text in the plot. 895 Default: 18. 896 figure_size : int 897 Size in inches of the image. 898 Default: 8 (8x8) 899 shading : str 900 This argument is directly passed down to matplotlib.axes.Axes.pcolormesh 901 see 902 https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py # noqa 903 Default: 'nearest' 904 905 Examples 906 -------- 907 908 >>> import yt 909 >>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046") 910 >>> ad = ds.all_data() 911 >>> plot = yt.PhasePlot( 912 ... ad, 913 ... ("gas", "density"), 914 ... ("gas", "temperature"), 915 ... [("gas", "mass")], 916 ... weight_field=None, 917 ... ) 918 >>> plot.save() 919 920 >>> # Change plot properties. 921 >>> plot.set_cmap(("gas", "mass"), "jet") 922 >>> plot.set_zlim(("gas", "mass"), 1e8, 1e13) 923 >>> plot.annotate_title("This is a phase plot") 924 925 """ 926 x_log = None 927 y_log = None 928 plot_title = None 929 _plot_valid = False 930 _profile_valid = False 931 _plot_type = "Phase" 932 _xlim = (None, None) 933 _ylim = (None, None) 934 935 def __init__( 936 self, 937 data_source, 938 x_field, 939 y_field, 940 z_fields, 941 weight_field=("gas", "mass"), 942 x_bins=128, 943 y_bins=128, 944 accumulation=False, 945 fractional=False, 946 fontsize=18, 947 figure_size=8.0, 948 shading="nearest", 949 ): 950 951 data_source = data_object_or_all_data(data_source) 952 953 if isinstance(z_fields, tuple): 954 z_fields = [z_fields] 955 z_fields = list(always_iterable(z_fields)) 956 957 if isinstance(data_source.ds, YTProfileDataset): 958 profile = data_source.ds.profile 959 else: 960 profile = create_profile( 961 data_source, 962 [x_field, y_field], 963 z_fields, 964 n_bins=[x_bins, y_bins], 965 weight_field=weight_field, 966 accumulation=accumulation, 967 fractional=fractional, 968 ) 969 970 type(self)._initialize_instance( 971 self, data_source, profile, fontsize, figure_size, shading 972 ) 973 974 @classmethod 975 def _initialize_instance( 976 cls, obj, data_source, profile, fontsize, figure_size, shading 977 ): 978 obj.plot_title = {} 979 obj.z_log = {} 980 obj.z_title = {} 981 obj._initfinished = False 982 obj.x_log = None 983 obj.y_log = None 984 obj._plot_text = {} 985 obj._text_xpos = {} 986 obj._text_ypos = {} 987 obj._text_kwargs = {} 988 obj._profile = profile 989 obj._shading = shading 990 obj._profile_valid = True 991 obj._xlim = (None, None) 992 obj._ylim = (None, None) 993 super(PhasePlot, obj).__init__(data_source, figure_size, fontsize) 994 obj._setup_plots() 995 obj._initfinished = True 996 return obj 997 998 def _get_field_title(self, field_z, profile): 999 field_x = profile.x_field 1000 field_y = profile.y_field 1001 xfi = profile.field_info[field_x] 1002 yfi = profile.field_info[field_y] 1003 zfi = profile.field_info[field_z] 1004 x_unit = profile.x.units 1005 y_unit = profile.y.units 1006 z_unit = profile.field_units[field_z] 1007 fractional = profile.fractional 1008 x_label, y_label, z_label = self._get_axes_labels(field_z) 1009 x_title = x_label or self._get_field_label(field_x, xfi, x_unit) 1010 y_title = y_label or self._get_field_label(field_y, yfi, y_unit) 1011 z_title = z_label or self._get_field_label(field_z, zfi, z_unit, fractional) 1012 return (x_title, y_title, z_title) 1013 1014 def _get_field_label(self, field, field_info, field_unit, fractional=False): 1015 field_unit = field_unit.latex_representation() 1016 field_name = field_info.display_name 1017 if isinstance(field, tuple): 1018 field = field[1] 1019 if field_name is None: 1020 field_name = r"$\rm{" + field + r"}$" 1021 field_name = r"$\rm{" + field.replace("_", r"\ ").title() + r"}$" 1022 elif field_name.find("$") == -1: 1023 field_name = field_name.replace(" ", r"\ ") 1024 field_name = r"$\rm{" + field_name + r"}$" 1025 if fractional: 1026 label = field_name + r"$\rm{\ Probability\ Density}$" 1027 elif field_unit is None or field_unit == "": 1028 label = field_name 1029 else: 1030 label = field_name + r"$\ \ (" + field_unit + r")$" 1031 return label 1032 1033 def _get_field_log(self, field_z, profile): 1034 zfi = profile.field_info[field_z] 1035 if self.x_log is None: 1036 x_log = profile.x_log 1037 else: 1038 x_log = self.x_log 1039 if self.y_log is None: 1040 y_log = profile.y_log 1041 else: 1042 y_log = self.y_log 1043 if field_z in self.z_log: 1044 z_log = self.z_log[field_z] 1045 else: 1046 z_log = zfi.take_log 1047 scales = {True: "log", False: "linear"} 1048 return scales[x_log], scales[y_log], scales[z_log] 1049 1050 def _recreate_frb(self): 1051 # needed for API compatibility with PlotWindow 1052 pass 1053 1054 @property 1055 def profile(self): 1056 if not self._profile_valid: 1057 self._recreate_profile() 1058 return self._profile 1059 1060 @property 1061 def fields(self): 1062 return list(self.plots.keys()) 1063 1064 def _setup_plots(self): 1065 if self._plot_valid: 1066 return 1067 for f, data in self.profile.items(): 1068 fig = None 1069 axes = None 1070 cax = None 1071 draw_colorbar = True 1072 draw_axes = True 1073 zlim = (None, None) 1074 xlim = self._xlim 1075 ylim = self._ylim 1076 if f in self.plots: 1077 draw_colorbar = self.plots[f]._draw_colorbar 1078 draw_axes = self.plots[f]._draw_axes 1079 zlim = (self.plots[f].zmin, self.plots[f].zmax) 1080 if self.plots[f].figure is not None: 1081 fig = self.plots[f].figure 1082 axes = self.plots[f].axes 1083 cax = self.plots[f].cax 1084 1085 x_scale, y_scale, z_scale = self._get_field_log(f, self.profile) 1086 x_title, y_title, z_title = self._get_field_title(f, self.profile) 1087 1088 if zlim == (None, None): 1089 if z_scale == "log": 1090 positive_values = data[data > 0.0] 1091 if len(positive_values) == 0: 1092 mylog.warning( 1093 "Profiled field %s has no positive values. Max = %f.", 1094 f, 1095 np.nanmax(data), 1096 ) 1097 mylog.warning("Switching to linear colorbar scaling.") 1098 zmin = np.nanmin(data) 1099 z_scale = "linear" 1100 self._field_transform[f] = linear_transform 1101 else: 1102 zmin = positive_values.min() 1103 self._field_transform[f] = log_transform 1104 else: 1105 zmin = np.nanmin(data) 1106 self._field_transform[f] = linear_transform 1107 zlim = [zmin, np.nanmax(data)] 1108 1109 font_size = self._font_properties.get_size() 1110 f = self.profile.data_source._determine_fields(f)[0] 1111 1112 # if this is a Particle Phase Plot AND if we using a single color, 1113 # override the colorbar here. 1114 splat_color = getattr(self, "splat_color", None) 1115 if splat_color is not None: 1116 cmap = matplotlib.colors.ListedColormap(splat_color, "dummy") 1117 else: 1118 cmap = self._colormap_config[f] 1119 1120 self.plots[f] = PhasePlotMPL( 1121 self.profile.x, 1122 self.profile.y, 1123 data, 1124 x_scale, 1125 y_scale, 1126 z_scale, 1127 cmap, 1128 zlim, 1129 self.figure_size, 1130 font_size, 1131 fig, 1132 axes, 1133 cax, 1134 shading=self._shading, 1135 ) 1136 1137 self.plots[f]._toggle_axes(draw_axes) 1138 self.plots[f]._toggle_colorbar(draw_colorbar) 1139 1140 self.plots[f].axes.xaxis.set_label_text(x_title) 1141 self.plots[f].axes.yaxis.set_label_text(y_title) 1142 self.plots[f].cax.yaxis.set_label_text(z_title) 1143 1144 self.plots[f].axes.set_xlim(xlim) 1145 self.plots[f].axes.set_ylim(ylim) 1146 1147 color = self._background_color[f] 1148 1149 self.plots[f].axes.set_facecolor(color) 1150 1151 if f in self._plot_text: 1152 self.plots[f].axes.text( 1153 self._text_xpos[f], 1154 self._text_ypos[f], 1155 self._plot_text[f], 1156 fontproperties=self._font_properties, 1157 **self._text_kwargs[f], 1158 ) 1159 1160 if f in self.plot_title: 1161 self.plots[f].axes.set_title(self.plot_title[f]) 1162 1163 # x-y axes minorticks 1164 if f not in self._minorticks: 1165 self._minorticks[f] = True 1166 if self._minorticks[f]: 1167 self.plots[f].axes.minorticks_on() 1168 else: 1169 self.plots[f].axes.minorticks_off() 1170 1171 # colorbar minorticks 1172 if f not in self._cbar_minorticks: 1173 self._cbar_minorticks[f] = True 1174 if self._cbar_minorticks[f]: 1175 if self._field_transform[f] == linear_transform: 1176 self.plots[f].cax.minorticks_on() 1177 elif MPL_VERSION < parse_version("3.0.0"): 1178 # before matplotlib 3 log-scaled colorbars internally used 1179 # a linear scale going from zero to one and did not draw 1180 # minor ticks. Since we want minor ticks, calculate 1181 # where the minor ticks should go in this linear scale 1182 # and add them manually. 1183 vmin = np.float64(self.plots[f].cb.norm.vmin) 1184 vmax = np.float64(self.plots[f].cb.norm.vmax) 1185 mticks = self.plots[f].image.norm(get_log_minorticks(vmin, vmax)) 1186 self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) 1187 else: 1188 self.plots[f].cax.minorticks_off() 1189 1190 self._set_font_properties() 1191 1192 # if this is a particle plot with one color only, hide the cbar here 1193 if hasattr(self, "use_cbar") and not self.use_cbar: 1194 self.plots[f].hide_colorbar() 1195 1196 self._plot_valid = True 1197 1198 @classmethod 1199 def from_profile(cls, profile, fontsize=18, figure_size=8.0, shading="nearest"): 1200 r""" 1201 Instantiate a PhasePlot object from a profile object created 1202 with :func:`~yt.data_objects.profiles.create_profile`. 1203 1204 Parameters 1205 ---------- 1206 profile : An instance of :class:`~yt.data_objects.profiles.ProfileND` 1207 A single profile object. 1208 fontsize : float 1209 The fontsize to use, in points. 1210 figure_size : float 1211 The figure size to use, in inches. 1212 shading : str 1213 This argument is directly passed down to matplotlib.axes.Axes.pcolormesh 1214 see 1215 https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py # noqa 1216 Default: 'nearest' 1217 1218 Examples 1219 -------- 1220 1221 >>> import yt 1222 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1223 >>> extrema = { 1224 ... ("gas", "density"): (1e-31, 1e-24), 1225 ... ("gas", "temperature"): (1e1, 1e8), 1226 ... ("gas", "mass"): (1e-6, 1e-1), 1227 ... } 1228 >>> profile = yt.create_profile( 1229 ... ds.all_data(), 1230 ... [("gas", "density"), ("gas", "temperature")], 1231 ... fields=[("gas", "mass")], 1232 ... extrema=extrema, 1233 ... fractional=True, 1234 ... ) 1235 >>> ph = yt.PhasePlot.from_profile(profile) 1236 >>> ph.save() 1237 """ 1238 obj = cls.__new__(cls) 1239 data_source = profile.data_source 1240 return cls._initialize_instance( 1241 obj, data_source, profile, fontsize, figure_size, shading 1242 ) 1243 1244 def annotate_text(self, xpos=0.0, ypos=0.0, text=None, **text_kwargs): 1245 r""" 1246 Allow the user to insert text onto the plot 1247 The x-position and y-position must be given as well as the text string. 1248 Add *text* tp plot at location *xpos*, *ypos* in plot coordinates 1249 (see example below). 1250 1251 Parameters 1252 ---------- 1253 xpos : float 1254 Position on plot in x-coordinates. 1255 ypos : float 1256 Position on plot in y-coordinates. 1257 text : str 1258 The text to insert onto the plot. 1259 **text_kwargs : dict 1260 Extra keyword arguments will be passed to matplotlib text instance 1261 1262 >>> plot.annotate_text(1e-15, 5e4, "Hello YT") 1263 1264 """ 1265 for f in self.data_source._determine_fields(list(self.plots.keys())): 1266 if self.plots[f].figure is not None and text is not None: 1267 self.plots[f].axes.text( 1268 xpos, 1269 ypos, 1270 text, 1271 fontproperties=self._font_properties, 1272 **text_kwargs, 1273 ) 1274 self._plot_text[f] = text 1275 self._text_xpos[f] = xpos 1276 self._text_ypos[f] = ypos 1277 self._text_kwargs[f] = text_kwargs 1278 return self 1279 1280 @validate_plot 1281 def save(self, name=None, suffix=".png", mpl_kwargs=None): 1282 r""" 1283 Saves a 2d profile plot. 1284 1285 Parameters 1286 ---------- 1287 name : str 1288 The output file keyword. 1289 suffix : string 1290 Specify the image type by its suffix. If not specified, the output 1291 type will be inferred from the filename. Defaults to PNG. 1292 mpl_kwargs : dict 1293 A dict of keyword arguments to be passed to matplotlib. 1294 1295 >>> plot.save(mpl_kwargs={"bbox_inches": "tight"}) 1296 1297 """ 1298 names = [] 1299 if not self._plot_valid: 1300 self._setup_plots() 1301 if mpl_kwargs is None: 1302 mpl_kwargs = {} 1303 if name is None: 1304 name = str(self.profile.ds) 1305 name = os.path.expanduser(name) 1306 xfn = self.profile.x_field 1307 yfn = self.profile.y_field 1308 if isinstance(xfn, tuple): 1309 xfn = xfn[1] 1310 if isinstance(yfn, tuple): 1311 yfn = yfn[1] 1312 for f in self.profile.field_data: 1313 _f = f 1314 if isinstance(f, tuple): 1315 _f = _f[1] 1316 middle = f"2d-Profile_{xfn}_{yfn}_{_f}" 1317 splitname = os.path.split(name) 1318 if splitname[0] != "" and not os.path.isdir(splitname[0]): 1319 os.makedirs(splitname[0]) 1320 if os.path.isdir(name) and name != str(self.profile.ds): 1321 name = name + (os.sep if name[-1] != os.sep else "") 1322 name += str(self.profile.ds) 1323 1324 new_name = validate_image_name(name, suffix) 1325 if new_name == name: 1326 for v in self.plots.values(): 1327 out_name = v.save(name, mpl_kwargs) 1328 names.append(out_name) 1329 return names 1330 1331 name = new_name 1332 prefix, suffix = os.path.splitext(name) 1333 name = f"{prefix}_{middle}{suffix}" 1334 1335 names.append(name) 1336 self.plots[f].save(name, mpl_kwargs) 1337 return names 1338 1339 @invalidate_plot 1340 def set_font(self, font_dict=None): 1341 """ 1342 1343 Set the font and font properties. 1344 1345 Parameters 1346 ---------- 1347 1348 font_dict : dict 1349 A dict of keyword parameters to be passed to 1350 :class:`matplotlib.font_manager.FontProperties`. 1351 1352 Possible keys include: 1353 1354 * family - The font family. Can be serif, sans-serif, cursive, 1355 'fantasy', or 'monospace'. 1356 * style - The font style. Either normal, italic or oblique. 1357 * color - A valid color string like 'r', 'g', 'red', 'cobalt', 1358 and 'orange'. 1359 * variant - Either normal or small-caps. 1360 * size - Either a relative value of xx-small, x-small, small, 1361 medium, large, x-large, xx-large or an absolute font size, e.g. 12 1362 * stretch - A numeric value in the range 0-1000 or one of 1363 ultra-condensed, extra-condensed, condensed, semi-condensed, 1364 normal, semi-expanded, expanded, extra-expanded or ultra-expanded 1365 * weight - A numeric value in the range 0-1000 or one of ultralight, 1366 light, normal, regular, book, medium, roman, semibold, demibold, 1367 demi, bold, heavy, extra bold, or black 1368 1369 See the matplotlib font manager API documentation for more details. 1370 https://matplotlib.org/stable/api/font_manager_api.html 1371 1372 Notes 1373 ----- 1374 1375 Mathtext axis labels will only obey the `size` and `color` keyword. 1376 1377 Examples 1378 -------- 1379 1380 This sets the font to be 24-pt, blue, sans-serif, italic, and 1381 bold-face. 1382 1383 >>> prof = ProfilePlot( 1384 ... ds.all_data(), ("gas", "density"), ("gas", "temperature") 1385 ... ) 1386 >>> slc.set_font( 1387 ... { 1388 ... "family": "sans-serif", 1389 ... "style": "italic", 1390 ... "weight": "bold", 1391 ... "size": 24, 1392 ... "color": "blue", 1393 ... } 1394 ... ) 1395 1396 """ 1397 from matplotlib.font_manager import FontProperties 1398 1399 if font_dict is None: 1400 font_dict = {} 1401 if "color" in font_dict: 1402 self._font_color = font_dict.pop("color") 1403 # Set default values if the user does not explicitly set them. 1404 # this prevents reverting to the matplotlib defaults. 1405 font_dict.setdefault("family", "stixgeneral") 1406 font_dict.setdefault("size", 18) 1407 self._font_properties = FontProperties(**font_dict) 1408 return self 1409 1410 @invalidate_plot 1411 def set_title(self, field, title): 1412 """Set a title for the plot. 1413 1414 Parameters 1415 ---------- 1416 field : str 1417 The z field of the plot to add the title. 1418 title : str 1419 The title to add. 1420 1421 Examples 1422 -------- 1423 1424 >>> plot.set_title(("gas", "mass"), "This is a phase plot") 1425 """ 1426 self.plot_title[self.data_source._determine_fields(field)[0]] = title 1427 return self 1428 1429 @invalidate_plot 1430 def annotate_title(self, title): 1431 """Set a title for the plot. 1432 1433 Parameters 1434 ---------- 1435 title : str 1436 The title to add. 1437 1438 Examples 1439 -------- 1440 1441 >>> plot.annotate_title("This is a phase plot") 1442 1443 """ 1444 for f in self._profile.field_data: 1445 if isinstance(f, tuple): 1446 f = f[1] 1447 self.plot_title[self.data_source._determine_fields(f)[0]] = title 1448 return self 1449 1450 @invalidate_plot 1451 def reset_plot(self): 1452 self.plots = {} 1453 return self 1454 1455 @invalidate_plot 1456 def set_log(self, field, log): 1457 """set a field to log or linear. 1458 1459 Parameters 1460 ---------- 1461 field : string 1462 the field to set a transform 1463 log : boolean 1464 Log on/off. 1465 """ 1466 p = self._profile 1467 if field == "all": 1468 self.x_log = log 1469 self.y_log = log 1470 for field in p.field_data: 1471 self.z_log[field] = log 1472 self._profile_valid = False 1473 else: 1474 (field,) = self.profile.data_source._determine_fields([field]) 1475 if field == p.x_field: 1476 self.x_log = log 1477 self._profile_valid = False 1478 elif field == p.y_field: 1479 self.y_log = log 1480 self._profile_valid = False 1481 elif field in p.field_data: 1482 self.z_log[field] = log 1483 else: 1484 raise KeyError(f"Field {field} not in phase plot!") 1485 return self 1486 1487 @invalidate_plot 1488 def set_unit(self, field, unit): 1489 """Sets a new unit for the requested field 1490 1491 Parameters 1492 ---------- 1493 field : string 1494 The name of the field that is to be changed. 1495 1496 unit : string or Unit object 1497 The name of the new unit. 1498 """ 1499 fd = self.data_source._determine_fields(field)[0] 1500 if fd == self.profile.x_field: 1501 self.profile.set_x_unit(unit) 1502 elif fd == self.profile.y_field: 1503 self.profile.set_y_unit(unit) 1504 elif fd in self.profile.field_data.keys(): 1505 self.profile.set_field_unit(field, unit) 1506 self.plots[field].zmin, self.plots[field].zmax = (None, None) 1507 else: 1508 raise KeyError(f"Field {field} not in phase plot!") 1509 return self 1510 1511 @invalidate_plot 1512 @invalidate_profile 1513 def set_xlim(self, xmin=None, xmax=None): 1514 """Sets the limits of the x bin field 1515 1516 Parameters 1517 ---------- 1518 1519 xmin : float or None 1520 The new x minimum in the current x-axis units. Defaults to None, 1521 which leaves the xmin unchanged. 1522 1523 xmax : float or None 1524 The new x maximum in the current x-axis units. Defaults to None, 1525 which leaves the xmax unchanged. 1526 1527 Examples 1528 -------- 1529 1530 >>> import yt 1531 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1532 >>> pp = yt.PhasePlot(ds.all_data(), "density", "temperature", ("gas", "mass")) 1533 >>> pp.set_xlim(1e-29, 1e-24) 1534 >>> pp.save() 1535 1536 """ 1537 p = self._profile 1538 if xmin is None: 1539 xmin = p.x_bins.min() 1540 elif not hasattr(xmin, "units"): 1541 xmin = self.ds.quan(xmin, p.x_bins.units) 1542 if xmax is None: 1543 xmax = p.x_bins.max() 1544 elif not hasattr(xmax, "units"): 1545 xmax = self.ds.quan(xmax, p.x_bins.units) 1546 self._xlim = (xmin, xmax) 1547 return self 1548 1549 @invalidate_plot 1550 @invalidate_profile 1551 def set_ylim(self, ymin=None, ymax=None): 1552 """Sets the plot limits for the y bin field. 1553 1554 Parameters 1555 ---------- 1556 1557 ymin : float or None 1558 The new y minimum in the current y-axis units. Defaults to None, 1559 which leaves the ymin unchanged. 1560 1561 ymax : float or None 1562 The new y maximum in the current y-axis units. Defaults to None, 1563 which leaves the ymax unchanged. 1564 1565 Examples 1566 -------- 1567 1568 >>> import yt 1569 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1570 >>> pp = yt.PhasePlot( 1571 ... ds.all_data(), 1572 ... ("gas", "density"), 1573 ... ("gas", "temperature"), 1574 ... ("gas", "mass"), 1575 ... ) 1576 >>> pp.set_ylim(1e4, 1e6) 1577 >>> pp.save() 1578 1579 """ 1580 p = self._profile 1581 if ymin is None: 1582 ymin = p.y_bins.min() 1583 elif not hasattr(ymin, "units"): 1584 ymin = self.ds.quan(ymin, p.y_bins.units) 1585 if ymax is None: 1586 ymax = p.y_bins.max() 1587 elif not hasattr(ymax, "units"): 1588 ymax = self.ds.quan(ymax, p.y_bins.units) 1589 self._ylim = (ymin, ymax) 1590 return self 1591 1592 def _recreate_profile(self): 1593 p = self._profile 1594 units = {p.x_field: str(p.x.units), p.y_field: str(p.y.units)} 1595 zunits = {field: str(p.field_units[field]) for field in p.field_units} 1596 extrema = {p.x_field: self._xlim, p.y_field: self._ylim} 1597 if self.x_log is not None or self.y_log is not None: 1598 logs = {} 1599 else: 1600 logs = None 1601 if self.x_log is not None: 1602 logs[p.x_field] = self.x_log 1603 if self.y_log is not None: 1604 logs[p.y_field] = self.y_log 1605 deposition = getattr(p, "deposition", None) 1606 additional_kwargs = { 1607 "accumulation": p.accumulation, 1608 "fractional": p.fractional, 1609 "deposition": deposition, 1610 } 1611 self._profile = create_profile( 1612 p.data_source, 1613 [p.x_field, p.y_field], 1614 list(p.field_map.values()), 1615 n_bins=[len(p.x_bins) - 1, len(p.y_bins) - 1], 1616 weight_field=p.weight_field, 1617 units=units, 1618 extrema=extrema, 1619 logs=logs, 1620 **additional_kwargs, 1621 ) 1622 for field in zunits: 1623 self._profile.set_field_unit(field, zunits[field]) 1624 self._profile_valid = True 1625 1626 1627class PhasePlotMPL(ImagePlotMPL): 1628 """A container for a single matplotlib figure and axes for a PhasePlot""" 1629 1630 def __init__( 1631 self, 1632 x_data, 1633 y_data, 1634 data, 1635 x_scale, 1636 y_scale, 1637 z_scale, 1638 cmap, 1639 zlim, 1640 figure_size, 1641 fontsize, 1642 figure, 1643 axes, 1644 cax, 1645 shading="nearest", 1646 ): 1647 self._initfinished = False 1648 self._draw_colorbar = True 1649 self._draw_axes = True 1650 self._figure_size = figure_size 1651 self._shading = shading 1652 # Compute layout 1653 fontscale = float(fontsize) / 18.0 1654 if fontscale < 1.0: 1655 fontscale = np.sqrt(fontscale) 1656 1657 if is_sequence(figure_size): 1658 self._cb_size = 0.0375 * figure_size[0] 1659 else: 1660 self._cb_size = 0.0375 * figure_size 1661 self._ax_text_size = [1.1 * fontscale, 0.9 * fontscale] 1662 self._top_buff_size = 0.30 * fontscale 1663 self._aspect = 1.0 1664 1665 size, axrect, caxrect = self._get_best_layout() 1666 1667 super().__init__(size, axrect, caxrect, zlim, figure, axes, cax) 1668 1669 self._init_image(x_data, y_data, data, x_scale, y_scale, z_scale, zlim, cmap) 1670 1671 self._initfinished = True 1672 1673 def _init_image( 1674 self, x_data, y_data, image_data, x_scale, y_scale, z_scale, zlim, cmap 1675 ): 1676 """Store output of imshow in image variable""" 1677 if z_scale == "log": 1678 norm = matplotlib.colors.LogNorm(zlim[0], zlim[1]) 1679 elif z_scale == "linear": 1680 norm = matplotlib.colors.Normalize(zlim[0], zlim[1]) 1681 self.image = None 1682 self.cb = None 1683 1684 self.image = self.axes.pcolormesh( 1685 np.array(x_data), 1686 np.array(y_data), 1687 np.array(image_data.T), 1688 norm=norm, 1689 cmap=cmap, 1690 shading=self._shading, 1691 ) 1692 1693 self.axes.set_xscale(x_scale) 1694 self.axes.set_yscale(y_scale) 1695 self.cb = self.figure.colorbar(self.image, self.cax) 1696 if z_scale == "linear": 1697 self.cb.formatter.set_scientific(True) 1698 self.cb.formatter.set_powerlimits((-2, 3)) 1699 self.cb.update_ticks() 1700