from collections import defaultdict

import numpy as np

from yt.funcs import is_sequence, mylog
from yt.units.unit_object import Unit
from yt.units.yt_array import YTArray
from yt.visualization.base_plot_types import PlotMPL
from yt.visualization.plot_container import (
    PlotContainer,
    PlotDictionary,
    invalidate_plot,
    linear_transform,
    log_transform,
)


class LineBuffer:
    r"""
    LineBuffer(ds, start_point, end_point, npoints, label = None)

    This takes a data source and implements a protocol for generating a
    'pixelized', fixed-resolution line buffer. In other words, LineBuffer
    takes a starting point, ending point, and number of sampling points and
    can subsequently generate YTArrays of field values along the sample points.

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object holding the data that can be sampled by the
        LineBuffer
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the LineBuffer.
        Must contain at least n elements where n is the dimensionality of the
        dataset. Values without units are interpreted as code_length.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the LineBuffer.
        Must contain at least n elements where n is the dimensionality of the
        dataset. Values without units are interpreted as code_length.
    npoints : int
        How many points to sample between start_point and end_point
    label : str, optional
        Label for this line. When the buffer is used by a LinePlot, it is
        prepended to the field label in the legend.
        Default: None

    Examples
    --------
    >>> lb = yt.LineBuffer(ds, (0.25, 0, 0), (0.25, 1, 0), 100)
    >>> lb[("all", "u")].max()
    0.11562424257143075 dimensionless

    """

    def __init__(self, ds, start_point, end_point, npoints, label=None):
        self.ds = ds
        # Points are validated and padded to 3 components (see _validate_point)
        self.start_point = _validate_point(start_point, ds, start=True)
        self.end_point = _validate_point(end_point, ds)
        self.npoints = npoints
        self.label = label
        # Cache of field -> sampled values, filled lazily by __getitem__
        self.data = {}

    def keys(self):
        """Return the fields that have been sampled (or set) so far."""
        return self.data.keys()

    def __setitem__(self, item, val):
        """Store ``val`` as the sampled data for field ``item``."""
        self.data[item] = val

    def __getitem__(self, item):
        """Return field ``item`` sampled along the line.

        The field is sampled with ``ds.coordinates.pixelize_line`` on first
        access and cached; the sample positions returned alongside the values
        are stored in ``self.points``.
        """
        if item in self.data:
            return self.data[item]
        mylog.info("Making a line buffer with %d points of %s", self.npoints, item)
        self.points, self.data[item] = self.ds.coordinates.pixelize_line(
            item, self.start_point, self.end_point, self.npoints
        )

        return self.data[item]

    def __delitem__(self, item):
        """Remove the cached data for field ``item``."""
        del self.data[item]


class LinePlotDictionary(PlotDictionary):
    """PlotDictionary that keys plots by the unit dimensions of a field.

    Every field is mapped to the first field seen with the same unit
    dimensions, so all fields sharing dimensions share a single plot entry.
    """

    def __init__(self, data_source):
        super().__init__(data_source)
        # unit dimensions -> first field registered with those dimensions
        self.known_dimensions = {}

    def _sanitize_dimensions(self, item):
        """Return the representative field for ``item``'s unit dimensions.

        If no field with the same dimensions has been seen yet, ``item`` itself
        becomes the representative and is returned.
        """
        field = self.data_source._determine_fields(item)[0]
        finfo = self.data_source.ds.field_info[field]
        dimensions = Unit(
            finfo.units, registry=self.data_source.ds.unit_registry
        ).dimensions
        if dimensions not in self.known_dimensions:
            self.known_dimensions[dimensions] = item
            ret_item = item
        else:
            ret_item = self.known_dimensions[dimensions]
        return ret_item

    def __getitem__(self, item):
        ret_item = self._sanitize_dimensions(item)
        return super().__getitem__(ret_item)

    def __setitem__(self, item, value):
        ret_item = self._sanitize_dimensions(item)
        super().__setitem__(ret_item, value)

    def __contains__(self, item):
        ret_item = self._sanitize_dimensions(item)
        return super().__contains__(ret_item)


class LinePlot(PlotContainer):
    r"""
    A class for constructing line plots

    Parameters
    ----------

    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object corresponding to the
        simulation output to be plotted.
    fields : string / tuple, or list of strings / tuples
        The name(s) of the field(s) to be plotted.
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the line.
        Must contain at least n elements where n is the dimensionality of the
        dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the line.
        Must contain at least n elements where n is the dimensionality of the
        dataset.
    npoints : int
        How many points to sample between start_point and end_point for
        constructing the line plot
    figure_size : int or two-element iterable of ints
        Size in inches of the image.
        Default: 5 (5x5)
    fontsize : int
        Font size for all text in the plot.
        Default: 14
    field_labels : dictionary
        Keys should be the field names. Values should be latex-formattable
        strings used in the LinePlot legend
        Default: None


    Example
    -------

    >>> import yt

    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> plot = yt.LinePlot(ds, "density", [0, 0, 0], [1, 1, 1], 512)
    >>> plot.annotate_legend("density")
    >>> plot.set_x_unit("cm")
    >>> plot.set_unit("density", "kg/cm**3")
    >>> plot.save()

    """
    _plot_type = "line_plot"

    def __init__(
        self,
        ds,
        fields,
        start_point,
        end_point,
        npoints,
        figure_size=5,
        fontsize=14,
        field_labels=None,
    ):
        """
        Sets up figure and axes
        """
        line = LineBuffer(ds, start_point, end_point, npoints, label=None)
        self.lines = [line]
        self._initialize_instance(self, ds, fields, figure_size, fontsize, field_labels)
        self._setup_plots()

    @classmethod
    def _initialize_instance(
        cls, obj, ds, fields, figure_size=5, fontsize=14, field_labels=None
    ):
        """Initialize plot state on ``obj``; shared by __init__ and from_lines.

        Sets up units/titles bookkeeping, resolves ``fields``, chooses a log or
        linear transform per field from its ``take_log`` flag, and fills in a
        default legend label (the field name) for fields without one.
        """
        obj._x_unit = None
        obj._y_units = {}
        obj._titles = {}

        data_source = ds.all_data()

        obj.fields = data_source._determine_fields(fields)
        obj.plots = LinePlotDictionary(data_source)
        obj.include_legend = defaultdict(bool)
        super(LinePlot, obj).__init__(data_source, figure_size, fontsize)
        for f in obj.fields:
            finfo = obj.data_source.ds._get_field_info(*f)
            if finfo.take_log:
                obj._field_transform[f] = log_transform
            else:
                obj._field_transform[f] = linear_transform

        if field_labels is None:
            obj.field_labels = {}
        else:
            # Copy so that filling in default labels below does not mutate
            # the dictionary passed in by the caller.
            obj.field_labels = dict(field_labels)
        for f in obj.fields:
            if f not in obj.field_labels:
                obj.field_labels[f] = f[1]

    @classmethod
    def from_lines(
        cls, ds, fields, lines, figure_size=5, font_size=14, field_labels=None
    ):
        """
        A class method for constructing a line plot from multiple sampling lines

        Parameters
        ----------

        ds : :class:`yt.data_objects.static_output.Dataset`
            This is the dataset object corresponding to the
            simulation output to be plotted.
        fields : field name or list of field names
            The name(s) of the field(s) to be plotted.
        lines : list of :class:`yt.visualization.line_plot.LineBuffer` instances
            The lines from which to sample data
        figure_size : int or two-element iterable of ints
            Size in inches of the image.
            Default: 5 (5x5)
        font_size : int
            Font size for all text in the plot.
            Default: 14
        field_labels : dictionary
            Keys should be the field names. Values should be latex-formattable
            strings used in the LinePlot legend
            Default: None

        Example
        -------
        >>> ds = yt.load(
        ...     "SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1
        ... )
        >>> fields = [field for field in ds.field_list if field[0] == "all"]
        >>> lines = [
        ...     yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label="x = 0.25"),
        ...     yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label="x = 0.5"),
        ... ]

        >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
        >>> plot.save()

        """
        obj = cls.__new__(cls)
        obj.lines = lines
        cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
        obj._setup_plots()
        return obj

    def _get_plot_instance(self, field):
        """Return the PlotMPL for ``field``, creating it if needed.

        Lookups go through ``self.plots`` (a LinePlotDictionary), so fields with
        the same unit dimensions share one plot. The axes rectangle is computed
        as figure fractions from the font scale and figure size.
        """
        fontscale = self._font_properties._size / 14.0
        top_buff_size = 0.35 * fontscale

        x_axis_size = 1.35 * fontscale
        y_axis_size = 0.7 * fontscale
        right_buff_size = 0.2 * fontscale

        if is_sequence(self.figure_size):
            figure_size = self.figure_size
        else:
            figure_size = (self.figure_size, self.figure_size)

        xbins = np.array([x_axis_size, figure_size[0], right_buff_size])
        ybins = np.array([y_axis_size, figure_size[1], top_buff_size])

        size = [xbins.sum(), ybins.sum()]

        x_frac_widths = xbins / size[0]
        y_frac_widths = ybins / size[1]

        # (left, bottom, width, height) as fractions of the total size
        axrect = (
            x_frac_widths[0],
            y_frac_widths[0],
            x_frac_widths[1],
            y_frac_widths[1],
        )

        try:
            plot = self.plots[field]
        except KeyError:
            plot = PlotMPL(self.figure_size, axrect, None, None)
            self.plots[field] = plot
        return plot

    def _setup_plots(self):
        """Draw every field, sampled along every line, onto its plot axes.

        Does nothing if the plots are already valid. Otherwise all existing
        axes are cleared. For each line in ``self.lines`` and each field in
        ``self.fields``, the field is sampled with
        ``ds.coordinates.pixelize_line``, converted to the requested units and
        drawn. Scales, axis labels, titles and legends are then applied.
        """
        if self._plot_valid:
            return
        for plot in self.plots.values():
            plot.axes.cla()

        # Count how many requested fields share each set of unit dimensions.
        # Such fields share one plot and get a generic "Multiple Fields" y
        # label. This does not depend on the line, so compute it only once.
        dimensions_counter = defaultdict(int)
        for field in self.fields:
            finfo = self.ds.field_info[field]
            dimensions = Unit(finfo.units, registry=self.ds.unit_registry).dimensions
            dimensions_counter[dimensions] += 1

        for line in self.lines:
            for field in self.fields:
                # get plot instance
                plot = self._get_plot_instance(field)

                # calculate x and y
                x, y = self.ds.coordinates.pixelize_line(
                    field, line.start_point, line.end_point, line.npoints
                )

                # scale x and y to proper units
                if self._x_unit is None:
                    unit_x = x.units
                else:
                    unit_x = self._x_unit

                if field in self._y_units:
                    unit_y = self._y_units[field]
                else:
                    unit_y = y.units

                x = x.to(unit_x)
                y = y.to(unit_y)

                # determine legend label: "<line label>; <field label>",
                # skipping whichever part is empty/None
                str_seq = []
                str_seq.append(line.label)
                str_seq.append(self.field_labels[field])
                delim = "; "
                legend_label = delim.join(filter(None, str_seq))

                # apply plot to matplotlib axes
                plot.axes.plot(x, y, label=legend_label)

                # apply log transforms if requested; symlog handles negatives
                if self._field_transform[field] != linear_transform:
                    if (y < 0).any():
                        plot.axes.set_yscale("symlog")
                    else:
                        plot.axes.set_yscale("log")

                # set font properties
                plot._set_font_properties(self._font_properties, None)

                # set x and y axis labels
                axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)

                if self._xlabel is not None:
                    x_label = self._xlabel
                else:
                    x_label = r"$\rm{Path\ Length" + axes_unit_labels[0] + "}$"

                if self._ylabel is not None:
                    y_label = self._ylabel
                else:
                    finfo = self.ds.field_info[field]
                    dimensions = Unit(
                        finfo.units, registry=self.ds.unit_registry
                    ).dimensions
                    if dimensions_counter[dimensions] > 1:
                        y_label = (
                            r"$\rm{Multiple\ Fields}$"
                            + r"$\rm{"
                            + axes_unit_labels[1]
                            + "}$"
                        )
                    else:
                        y_label = (
                            finfo.get_latex_display_name()
                            + r"$\rm{"
                            + axes_unit_labels[1]
                            + "}$"
                        )

                plot.axes.set_xlabel(x_label)
                plot.axes.set_ylabel(y_label)

                # apply title
                if field in self._titles:
                    plot.axes.set_title(self._titles[field])

                # apply legend (keyed by the dimension-representative field)
                dim_field = self.plots._sanitize_dimensions(field)
                if self.include_legend[dim_field]:
                    plot.axes.legend()

        self._plot_valid = True

    @invalidate_plot
    def annotate_legend(self, field):
        """
        Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
        call ensures that a legend label will be added for every field of
        a multi-field plot
        """
        dim_field = self.plots._sanitize_dimensions(field)
        self.include_legend[dim_field] = True

    @invalidate_plot
    def set_x_unit(self, unit_name):
        """Set the unit to use along the x-axis

        Parameters
        ----------
        unit_name: str
          The name of the unit to use for the x-axis unit
        """
        self._x_unit = unit_name

    @invalidate_plot
    def set_unit(self, field, unit_name):
        """Set the unit used to plot the field

        Parameters
        ----------
        field: str or field tuple
           The name of the field to set the units for
        unit_name: str
           The name of the unit to use for this field
        """
        self._y_units[self.data_source._determine_fields(field)[0]] = unit_name

    @invalidate_plot
    def annotate_title(self, field, title):
        """Set the title of the plot containing the field

        Parameters
        ----------
        field: str or field tuple
           The name of the field whose plot gets the title
        title: str
           The title to use for the plot
        """
        self._titles[self.data_source._determine_fields(field)[0]] = title


def _validate_point(point, ds, start=False):
    """Validate a line endpoint and pad it to three components.

    Parameters
    ----------
    point : array-like or YTArray
        Coordinates of the point. Values without units are interpreted as
        code_length.
    ds : Dataset
        Dataset whose dimensionality and unit system are used.
    start : bool
        Whether this is the start point. Missing trailing coordinates are
        padded with 0 for the start point and 1 for the end point.

    Returns
    -------
    YTArray
        A 1D array with at least three components.

    Raises
    ------
    RuntimeError
        If ``point`` is not array-like, not 1D, or has fewer elements than
        the dataset dimensionality.
    """
    if not is_sequence(point):
        raise RuntimeError("Input point must be array-like")
    if not isinstance(point, YTArray):
        point = ds.arr(point, "code_length", dtype=np.float64)
    if len(point.shape) != 1:
        raise RuntimeError("Input point must be a 1D array")
    if point.shape[0] < ds.dimensionality:
        raise RuntimeError("Input point must have an element for each dimension")
    # need to pad to 3D elements to avoid issues later
    if point.shape[0] < 3:
        if start:
            val = 0
        else:
            val = 1
        # pad based on the point's actual length so the result has exactly
        # 3 components even if more than ds.dimensionality were given
        point = np.append(point.d, [val] * (3 - point.shape[0])) * point.uq
    return point