1from io import BytesIO
2
3import numpy as np
4
5from yt.data_objects.profiles import create_profile
6from yt.funcs import mylog
7from yt.visualization.volume_rendering.transfer_functions import ColorTransferFunction
8
9
class TransferFunctionHelper:
    r"""A transfer function helper.

    This attempts to help set up a good transfer function by finding
    bounds, handling linear/log options, and displaying the transfer
    function combined with 1D profiles of rendering quantity.

    Parameters
    ----------
    ds: A Dataset instance
        A static output that is currently being rendered. This is used to
        help set up data bounds.

    Notes
    -----
    The helper is stateful. Set the field first with ``set_field``, then
    optionally adjust ``set_log`` / ``set_bounds``. Next, build the transfer
    function with ``build_transfer_function`` and add layers to it (for
    example with ``setup_default``). ``plot`` performs the build and default
    setup automatically if no transfer function exists yet.
    """

    # Per-instance cache of 1D profiles keyed by the rendered field name;
    # replaced by a fresh dict in __init__.
    profiles = None

    def __init__(self, ds):
        self.ds = ds
        # Name of the field being rendered; set via set_field().
        self.field = None
        # Whether the transfer function is built in log10 space.
        self.log = False
        # The ColorTransferFunction, created by build_transfer_function().
        self.tf = None
        # Raw (non-logged) [min, max] bounds; computed from the data if None.
        self.bounds = None
        # Forwarded to ColorTransferFunction when it is built.
        self.grey_opacity = False
        self.profiles = {}

    def set_bounds(self, bounds=None):
        """
        Set the bounds of the transfer function.

        Parameters
        ----------
        bounds: array-like, length 2, optional
            A length 2 list/array in the form [min, max]. These should be the
            raw values and not the logarithm of the min and max. If bounds is
            None, the bounds of the data are calculated from all of the data
            in the dataset.  This can be slow for very large datasets.

        Raises
        ------
        AssertionError
            If ``bounds`` does not have exactly two entries, or if log scaling
            is enabled and either bound is not strictly positive. When the
            check fails, the previously stored bounds are left untouched.
        """
        if bounds is None:
            bounds = self.ds.all_data().quantities["Extrema"](self.field, non_zero=True)
            # Work with plain ndarray views of the extrema.
            bounds = [b.ndarray_view() for b in bounds]

        # Validate before storing, so a rejected value does not leave the
        # helper holding bounds that later break the log10 calls.
        assert len(bounds) == 2
        if self.log:
            assert bounds[0] > 0.0
            assert bounds[1] > 0.0
        self.bounds = bounds

    def set_field(self, field):
        """
        Set the field to be rendered.

        When the field differs from the current one, the log/linear setting
        is reset to the field's ``take_log`` default.

        Parameters
        ----------
        field: string
            The field to be rendered.
        """
        if field != self.field:
            self.log = self.ds._get_field_info(field).take_log
        self.field = field

    def set_log(self, log):
        """
        Set whether or not the transfer function should be in log or linear
        space.

        Only this helper's own setting is changed; the dataset's field info
        is not modified. Note that a later ``set_field`` call with a
        different field resets this to that field's ``take_log`` default.

        Parameters
        ----------
        log: boolean
            Sets whether the transfer function should use log or linear space.
        """
        self.log = log

    def build_transfer_function(self):
        """
        Builds the transfer function according to the current state of the
        TransferFunctionHelper.

        If no bounds are set, they are first computed from all of the data
        (see ``set_bounds``). Any previously built transfer function, and any
        layers added to it, is replaced.

        Returns
        -------

        A ColorTransferFunction object with 512 bins spanning the bounds
        (in log10 space when log scaling is enabled). It is also stored as
        ``self.tf``.

        """
        if self.bounds is None:
            mylog.info(
                "Calculating data bounds. This may take a while. "
                "Set the TransferFunctionHelper.bounds to avoid this."
            )
            self.set_bounds()

        if self.log:
            mi, ma = np.log10(self.bounds[0]), np.log10(self.bounds[1])
        else:
            mi, ma = self.bounds

        self.tf = ColorTransferFunction(
            (mi, ma), grey_opacity=self.grey_opacity, nbins=512
        )
        return self.tf

    def setup_default(self):
        """Setup a default colormap

        Adds 10 gaussian layers whose colors sample the 'nipy_spectral'
        colormap to the ColorTransferFunction. The function is built first
        if none exists yet. The alpha channel is then rescaled so that its
        mean value is 2, which gives a natural contrast ratio.

        """
        if self.tf is None:
            self.build_transfer_function()
        self.tf.add_layers(10, colormap="nipy_spectral")
        alpha = self.tf.funcs[-1].y
        # size / sum == 1 / mean, so after scaling the mean alpha is 2.
        factor = alpha.size / alpha.sum()
        alpha *= 2 * factor

    def plot(self, fn=None, profile_field=None, profile_weight=None):
        """
        Save the current transfer function to a bitmap, or display
        it inline.

        Parameters
        ----------
        fn: string, optional
            Filename to save the image to. If None, an image is returned
            for display in an IPython session.
        profile_field: string, optional
            If given, a 1D profile of this field, binned by the rendered
            field, is overplotted and rescaled to the peak of the alpha
            channel.
        profile_weight: string, optional
            Weight field used when the profile must be created. It is
            ignored if a profile for the rendered field is already cached.

        Returns
        -------

        If fn is None, an ``IPython.display.Image`` holding a PNG rendering
        of the plot; otherwise None (the figure is written to ``fn``).

        """
        from matplotlib.figure import Figure

        from yt.visualization._mpl_imports import FigureCanvasAgg

        if self.tf is None:
            self.build_transfer_function()
            self.setup_default()
        tf = self.tf
        if self.log:
            xfunc = np.logspace
            xmi, xma = np.log10(self.bounds[0]), np.log10(self.bounds[1])
        else:
            xfunc = np.linspace
            # Need to strip units off of the bounds to avoid a recursion error
            # in matplotlib 1.3.1
            xmi, xma = (np.float64(b) for b in self.bounds)

        x = xfunc(xmi, xma, tf.nbins)
        # Alpha channel of the transfer function: used as the bar heights.
        y = tf.funcs[3].y
        # Bar widths: spacing between bin positions, repeating the last gap.
        w = np.append(x[1:] - x[:-1], x[-1] - x[-2])
        colors = np.array(
            [tf.funcs[0].y, tf.funcs[1].y, tf.funcs[2].y, np.ones_like(x)]
        ).T

        fig = Figure(figsize=[6, 3])
        canvas = FigureCanvasAgg(fig)
        ax = fig.add_axes([0.2, 0.2, 0.75, 0.75])
        ax.bar(
            x,
            y,
            w,
            edgecolor=[0.0, 0.0, 0.0, 0.0],
            log=self.log,
            color=colors,
            bottom=[0],
        )

        if profile_field is not None:
            # Profiles are cached per rendered field (not per weight field).
            try:
                prof = self.profiles[self.field]
            except KeyError:
                self.setup_profile(profile_field, profile_weight)
                prof = self.profiles[self.field]
            try:
                prof[profile_field]
            except KeyError:
                prof.add_fields([profile_field])
            # Strip units, if any, for matplotlib 1.3.1
            xplot = np.array(prof.x)
            yplot = np.array(prof[profile_field])
            pmax = yplot.max()
            # Rescale to the alpha peak; skip when the profile is identically
            # zero instead of producing NaNs from 0 / 0.
            if pmax != 0:
                yplot = yplot * y.max() / pmax
            ax.plot(xplot, yplot, color="w", linewidth=3)
            ax.plot(xplot, yplot, color="k")

        ax.set_xscale("log" if self.log else "linear")
        ax.set_xlim(x.min(), x.max())
        ax.set_xlabel(self.ds._get_field_info(self.field).get_label())
        ax.set_ylabel(r"$\mathrm{alpha}$")
        ax.set_ylim(y.max() * 1.0e-3, y.max() * 2)

        if fn is None:
            from IPython.display import Image

            buf = BytesIO()
            # Force PNG explicitly: otherwise print_figure follows
            # rcParams["savefig.format"], which may not match what Image
            # assumes for raw bytes.
            canvas.print_figure(buf, format="png")
            return Image(buf.getvalue())
        else:
            fig.savefig(fn)

    def setup_profile(self, profile_field=None, profile_weight=None):
        """
        Create a 1D profile binned by the rendered field and cache it.

        The profile uses 128 bins spanning ``self.bounds``, in log or linear
        space according to ``self.log``. It is stored in
        ``self.profiles[self.field]``.

        Parameters
        ----------
        profile_field: string, optional
            Field to profile; defaults to "cell_volume".
        profile_weight: string, optional
            Weight field passed to ``create_profile``.
        """
        if profile_field is None:
            profile_field = "cell_volume"
        if self.bounds is None:
            # Keep the profile's bin range consistent with the transfer
            # function bounds; compute them from the data if not yet set.
            self.set_bounds()
        prof = create_profile(
            self.ds.all_data(),
            self.field,
            profile_field,
            n_bins=128,
            extrema={self.field: self.bounds},
            weight_field=profile_weight,
            logs={self.field: self.log},
        )
        self.profiles[self.field] = prof
232