1from collections import defaultdict
2
3import numpy as np
4
5from yt.funcs import is_sequence, mylog
6from yt.units.unit_object import Unit
7from yt.units.yt_array import YTArray
8from yt.visualization.base_plot_types import PlotMPL
9from yt.visualization.plot_container import (
10    PlotContainer,
11    PlotDictionary,
12    invalidate_plot,
13    linear_transform,
14    log_transform,
15)
16
17
class LineBuffer:
    r"""
    LineBuffer(ds, start_point, end_point, npoints, label = None)

    This takes a data source and implements a protocol for generating a
    'pixelized', fixed-resolution line buffer. In other words, LineBuffer
    takes a starting point, ending point, and number of sampling points and
    can subsequently generate YTArrays of field values along the sample points.

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object holding the data that can be sampled by the
        LineBuffer
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the LineBuffer.
        Must contain n elements where n is the dimensionality of the dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the LineBuffer.
        Must contain n elements where n is the dimensionality of the dataset.
    npoints : int
        How many points to sample between start_point and end_point
    label : str, optional
        A label identifying this line. When the line is drawn by a LinePlot,
        the label is prepended (joined with "; ") to each legend entry.
        Default: None

    Examples
    --------
    >>> lb = yt.LineBuffer(ds, (0.25, 0, 0), (0.25, 1, 0), 100)
    >>> lb[("all", "u")].max()
    0.11562424257143075 dimensionless

    """

    def __init__(self, ds, start_point, end_point, npoints, label=None):
        self.ds = ds
        # both endpoints are validated and padded to 3 coordinates
        self.start_point = _validate_point(start_point, ds, start=True)
        self.end_point = _validate_point(end_point, ds)
        self.npoints = npoints
        self.label = label
        # cache of sampled values, keyed by the field passed to __getitem__
        self.data = {}
        # first value returned by the most recent pixelize_line call
        # (LinePlot plots it as path length); None until a field is sampled
        self.points = None

    def keys(self):
        """Return a view of the fields sampled (or assigned) so far."""
        return self.data.keys()

    def __setitem__(self, item, val):
        """Store *val* as the cached values for field *item*."""
        self.data[item] = val

    def __getitem__(self, item):
        """Return the values of field *item* along the line.

        Values are sampled with ``ds.coordinates.pixelize_line`` on first
        access and cached in ``self.data``; that call also updates
        ``self.points``.
        """
        if item in self.data:
            return self.data[item]
        mylog.info("Making a line buffer with %d points of %s", self.npoints, item)
        self.points, self.data[item] = self.ds.coordinates.pixelize_line(
            item, self.start_point, self.end_point, self.npoints
        )

        return self.data[item]

    def __delitem__(self, item):
        """Drop the cached values for field *item*."""
        del self.data[item]
75
76
class LinePlotDictionary(PlotDictionary):
    """Plot dictionary that routes every field to a per-dimensions key.

    Each key passed in is mapped through ``_sanitize_dimensions``: the first
    field seen with a given set of unit dimensions becomes the canonical key
    for those dimensions, and any later field with the same dimensions is
    redirected to it. Fields that share dimensions therefore share a plot.
    """

    def __init__(self, data_source):
        super().__init__(data_source)
        # unit dimensions -> first field (as passed in) seen with them
        self.known_dimensions = {}

    def _sanitize_dimensions(self, item):
        """Return the canonical key for the unit dimensions of *item*.

        If no field with these dimensions has been seen yet, *item* is
        registered as the canonical key. This happens on every access path,
        including ``in`` checks.
        """
        ds = self.data_source.ds
        resolved = self.data_source._determine_fields(item)[0]
        units = ds.field_info[resolved].units
        dims = Unit(units, registry=ds.unit_registry).dimensions
        return self.known_dimensions.setdefault(dims, item)

    def __getitem__(self, item):
        return super().__getitem__(self._sanitize_dimensions(item))

    def __setitem__(self, item, value):
        super().__setitem__(self._sanitize_dimensions(item), value)

    def __contains__(self, item):
        return super().__contains__(self._sanitize_dimensions(item))
106
107
class LinePlot(PlotContainer):
    r"""
    A class for constructing line plots

    Parameters
    ----------

    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object corresponding to the
        simulation output to be plotted.
    fields : string / tuple, or list of strings / tuples
        The name(s) of the field(s) to be plotted.
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the line.
        Must contain n elements where n is the dimensionality of the dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the line.
        Must contain n elements where n is the dimensionality of the dataset.
    npoints : int
        How many points to sample between start_point and end_point for
        constructing the line plot
    figure_size : int or two-element iterable of ints
        Size in inches of the image.
        Default: 5 (5x5)
    fontsize : int
        Font size for all text in the plot.
        Default: 14
    field_labels : dictionary
        Keys should be the field names, either as full ``(ftype, fname)``
        tuples or as bare field-name strings. Values should be
        latex-formattable strings used in the LinePlot legend. The dictionary
        is copied, never modified.
        Default: None


    Examples
    --------

    >>> import yt

    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> plot = yt.LinePlot(ds, "density", [0, 0, 0], [1, 1, 1], 512)
    >>> plot.annotate_legend("density")
    >>> plot.set_x_unit("cm")
    >>> plot.set_unit("density", "kg/cm**3")
    >>> plot.save()

    """
    _plot_type = "line_plot"

    def __init__(
        self,
        ds,
        fields,
        start_point,
        end_point,
        npoints,
        figure_size=5,
        fontsize=14,
        field_labels=None,
    ):
        """
        Sample a single line from start_point to end_point and set up the
        figure and axes for *fields*.
        """
        line = LineBuffer(ds, start_point, end_point, npoints, label=None)
        self.lines = [line]
        self._initialize_instance(self, ds, fields, figure_size, fontsize, field_labels)
        self._setup_plots()

    @classmethod
    def _initialize_instance(
        cls, obj, ds, fields, figure_size=5, fontsize=14, field_labels=None
    ):
        """Initialize the plotting state of *obj*.

        Shared by ``__init__`` and ``from_lines``. Resolves *fields* against
        ``ds.all_data()``, chooses a log or linear transform per field from
        its field info (``take_log``), and builds ``obj.field_labels``.
        """
        obj._x_unit = None
        obj._y_units = {}
        obj._titles = {}

        data_source = ds.all_data()

        obj.fields = data_source._determine_fields(fields)
        obj.plots = LinePlotDictionary(data_source)
        obj.include_legend = defaultdict(bool)
        super(LinePlot, obj).__init__(data_source, figure_size, fontsize)
        for f in obj.fields:
            finfo = obj.data_source.ds._get_field_info(*f)
            if finfo.take_log:
                obj._field_transform[f] = log_transform
            else:
                obj._field_transform[f] = linear_transform

        # Copy so that filling in defaults below never mutates the caller's dict.
        obj.field_labels = {} if field_labels is None else dict(field_labels)
        for f in obj.fields:
            if f in obj.field_labels:
                continue
            # obj.fields holds resolved (ftype, fname) tuples; also accept a
            # label keyed by the bare field name, defaulting to that name.
            obj.field_labels[f] = obj.field_labels.get(f[1], f[1])

    @classmethod
    def from_lines(
        cls, ds, fields, lines, figure_size=5, font_size=14, field_labels=None
    ):
        """
        A class method for constructing a line plot from multiple sampling lines

        Parameters
        ----------

        ds : :class:`yt.data_objects.static_output.Dataset`
            This is the dataset object corresponding to the
            simulation output to be plotted.
        fields : field name or list of field names
            The name(s) of the field(s) to be plotted.
        lines : list of :class:`yt.visualization.line_plot.LineBuffer` instances
            The lines from which to sample data
        figure_size : int or two-element iterable of ints
            Size in inches of the image.
            Default: 5 (5x5)
        font_size : int
            Font size for all text in the plot.
            Default: 14
        field_labels : dictionary
            Keys should be the field names, either as full ``(ftype, fname)``
            tuples or as bare field-name strings. Values should be
            latex-formattable strings used in the LinePlot legend
            Default: None

        Examples
        --------
        >>> ds = yt.load(
        ...     "SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1
        ... )
        >>> fields = [field for field in ds.field_list if field[0] == "all"]
        >>> lines = [
        ...     yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label="x = 0.25"),
        ...     yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label="x = 0.5"),
        ... ]

        >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
        >>> plot.save()

        """
        obj = cls.__new__(cls)
        obj.lines = lines
        cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
        obj._setup_plots()
        return obj

    def _get_plot_instance(self, field):
        """Return the PlotMPL used for *field*, creating it on first use.

        Lookup goes through LinePlotDictionary, so fields with identical unit
        dimensions share a single plot.
        """
        fontscale = self._font_properties._size / 14.0
        top_buff_size = 0.35 * fontscale

        x_axis_size = 1.35 * fontscale
        y_axis_size = 0.7 * fontscale
        right_buff_size = 0.2 * fontscale

        if is_sequence(self.figure_size):
            figure_size = self.figure_size
        else:
            figure_size = (self.figure_size, self.figure_size)

        # widths (inches) of: left/bottom margin, axes area, right/top margin
        xbins = np.array([x_axis_size, figure_size[0], right_buff_size])
        ybins = np.array([y_axis_size, figure_size[1], top_buff_size])

        size = [xbins.sum(), ybins.sum()]

        x_frac_widths = xbins / size[0]
        y_frac_widths = ybins / size[1]

        axrect = (
            x_frac_widths[0],
            y_frac_widths[0],
            x_frac_widths[1],
            y_frac_widths[1],
        )

        try:
            plot = self.plots[field]
        except KeyError:
            # NOTE(review): axrect fractions are computed relative to `size`
            # (figure_size plus margins), yet the figure is created with
            # self.figure_size -- confirm which size is intended.
            plot = PlotMPL(self.figure_size, axrect, None, None)
            self.plots[field] = plot
        return plot

    def _setup_plots(self):
        """Redraw every (line, field) pair onto its plot's axes.

        Returns immediately if the plots are already valid. Otherwise the
        axes of every existing plot are cleared, and each field is sampled
        along each line, converted to the requested units and drawn with its
        y scale, axis labels, title and (optionally) legend.
        """
        if self._plot_valid:
            return
        for plot in self.plots.values():
            plot.axes.cla()

        # The unit dimensions of each field, and how many plotted fields share
        # each set of dimensions, do not depend on the line being drawn, so
        # compute them once rather than once per line. Fields sharing
        # dimensions share a plot and get a generic "Multiple Fields" y label.
        field_dimensions = {}
        dimensions_counter = defaultdict(int)
        for field in self.fields:
            finfo = self.ds.field_info[field]
            dimensions = Unit(finfo.units, registry=self.ds.unit_registry).dimensions
            field_dimensions[field] = dimensions
            dimensions_counter[dimensions] += 1

        for line in self.lines:
            for field in self.fields:
                # get plot instance
                plot = self._get_plot_instance(field)

                # sample the field along the line (x is plotted as path length)
                x, y = self.ds.coordinates.pixelize_line(
                    field, line.start_point, line.end_point, line.npoints
                )

                # scale x and y to the requested units (default: as sampled)
                unit_x = x.units if self._x_unit is None else self._x_unit
                unit_y = self._y_units.get(field, y.units)

                x = x.to(unit_x)
                y = y.to(unit_y)

                # legend label: "<line label>; <field label>", skipping empties
                legend_label = "; ".join(
                    filter(None, [line.label, self.field_labels[field]])
                )

                # apply plot to matplotlib axes
                plot.axes.plot(x, y, label=legend_label)

                # log-transformed fields get a log y axis; use symlog when any
                # value is negative, since a plain log axis cannot show those
                if self._field_transform[field] != linear_transform:
                    if (y < 0).any():
                        plot.axes.set_yscale("symlog")
                    else:
                        plot.axes.set_yscale("log")

                # set font properties
                plot._set_font_properties(self._font_properties, None)

                # set x and y axis labels
                axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)

                if self._xlabel is not None:
                    x_label = self._xlabel
                else:
                    x_label = r"$\rm{Path\ Length" + axes_unit_labels[0] + "}$"

                if self._ylabel is not None:
                    y_label = self._ylabel
                else:
                    if dimensions_counter[field_dimensions[field]] > 1:
                        name = r"$\rm{Multiple\ Fields}$"
                    else:
                        name = self.ds.field_info[field].get_latex_display_name()
                    y_label = name + r"$\rm{" + axes_unit_labels[1] + "}$"

                plot.axes.set_xlabel(x_label)
                plot.axes.set_ylabel(y_label)

                # apply title
                if field in self._titles:
                    plot.axes.set_title(self._titles[field])

                # apply legend (flag is stored per shared-dimensions key)
                dim_field = self.plots._sanitize_dimensions(field)
                if self.include_legend[dim_field]:
                    plot.axes.legend()

        self._plot_valid = True

    @invalidate_plot
    def annotate_legend(self, field):
        """
        Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
        call ensures that a legend label will be added for every field of
        a multi-field plot

        Parameters
        ----------
        field: str or field tuple
           A field drawn on the plot that should receive a legend
        """
        dim_field = self.plots._sanitize_dimensions(field)
        self.include_legend[dim_field] = True

    @invalidate_plot
    def set_x_unit(self, unit_name):
        """Set the unit to use along the x-axis

        Parameters
        ----------
        unit_name: str
          The name of the unit to use for the x-axis unit
        """
        self._x_unit = unit_name

    @invalidate_plot
    def set_unit(self, field, unit_name):
        """Set the unit used to plot the field

        Parameters
        ----------
        field: str or field tuple
           The name of the field to set the units for
        unit_name: str
           The name of the unit to use for this field
        """
        self._y_units[self.data_source._determine_fields(field)[0]] = unit_name

    @invalidate_plot
    def annotate_title(self, field, title):
        """Set the title of the plot containing the field

        Parameters
        ----------
        field: str or field tuple
           The name of the field whose plot receives the title
        title: str
           The title to use for the plot
        """
        self._titles[self.data_source._determine_fields(field)[0]] = title
436
437
def _validate_point(point, ds, start=False):
    """Validate a line endpoint and return it as a YTArray of >= 3 coordinates.

    Parameters
    ----------
    point : array-like or YTArray
        Coordinates of the point. Input that is not already a YTArray is
        converted with ``ds.arr`` in ``code_length`` units (float64).
    ds : Dataset
        Dataset the point refers to; ``ds.dimensionality`` is the minimum
        number of coordinates required.
    start : bool, optional
        Whether this is the start point of a line. Missing trailing
        coordinates are filled with 0 for a start point and 1 otherwise.

    Returns
    -------
    YTArray
        The point, padded to 3 coordinates if it had fewer.

    Raises
    ------
    RuntimeError
        If *point* is not array-like, is not one-dimensional, or has fewer
        elements than ``ds.dimensionality``.
    """
    if not is_sequence(point):
        raise RuntimeError("Input point must be array-like")
    if not isinstance(point, YTArray):
        point = ds.arr(point, "code_length", dtype=np.float64)
    if len(point.shape) != 1:
        raise RuntimeError("Input point must be a 1D array")
    if point.shape[0] < ds.dimensionality:
        raise RuntimeError("Input point must have an element for each dimension")
    # Pad to 3 coordinates to avoid issues later. The pad length must come
    # from the point itself: padding by 3 - ds.dimensionality over-pads a
    # point that already has extra coordinates (e.g. a 2-element point for a
    # 1D dataset used to come out with 4 elements).
    n_missing = 3 - point.shape[0]
    if n_missing > 0:
        fill = 0 if start else 1
        point = np.append(point.d, [fill] * n_missing) * point.uq
    return point
455