from io import BytesIO

import numpy as np

from yt.data_objects.profiles import create_profile
from yt.funcs import mylog
from yt.visualization.volume_rendering.transfer_functions import ColorTransferFunction


class TransferFunctionHelper:
    r"""A transfer function helper.

    This attempts to help set up a good transfer function by finding
    bounds, handling linear/log options, and displaying the transfer
    function combined with 1D profiles of rendering quantity.

    Parameters
    ----------
    ds: A Dataset instance
        A static output that is currently being rendered. This is used to
        help set up data bounds.

    Notes
    -----
    A typical workflow is ``set_field`` -> (optionally) ``set_log`` and
    ``set_bounds`` -> ``build_transfer_function`` -> ``setup_default``.
    ``plot`` performs the last two steps itself if no transfer function
    has been built yet.
    """

    profiles = None

    def __init__(self, ds):
        self.ds = ds
        self.field = None
        self.log = False
        self.tf = None
        self.bounds = None
        self.grey_opacity = False
        # Cache of 1D profiles keyed by the rendered field (see setup_profile).
        self.profiles = {}

    def set_bounds(self, bounds=None):
        """
        Set the bounds of the transfer function.

        Parameters
        ----------
        bounds: array-like, length 2, optional
            A length 2 list/array in the form [min, max]. These should be the
            raw values and not the logarithm of the min and max. If bounds is
            None, the bounds of the data are calculated from all of the data
            in the dataset. This can be slow for very large datasets.
        """
        if bounds is None:
            bounds = self.ds.all_data().quantities["Extrema"](self.field, non_zero=True)
            # Keep plain ndarray views of the extrema rather than unit-carrying
            # quantities.
            bounds = [b.ndarray_view() for b in bounds]
        self.bounds = bounds

        # Do some error checking.
        assert len(self.bounds) == 2, "bounds must be a length-2 [min, max] sequence"
        if self.log:
            # log10 is taken of both bounds later (build_transfer_function,
            # plot), so they must be strictly positive in log mode.
            assert self.bounds[0] > 0.0, "lower bound must be positive in log space"
            assert self.bounds[1] > 0.0, "upper bound must be positive in log space"

    def set_field(self, field):
        """
        Set the field to be rendered.

        When the field differs from the current one, ``self.log`` is reset
        to that field's ``take_log`` setting from the dataset's field info.

        Parameters
        ----------
        field: string
            The field to be rendered.
        """
        if field != self.field:
            self.log = self.ds._get_field_info(field).take_log
        self.field = field

    def set_log(self, log):
        """
        Set whether or not the transfer function should be in log or linear
        space.

        This only changes this helper's ``log`` attribute; the dataset's
        field info is not modified. A later ``set_field`` call with a
        different field resets ``log`` to that field's ``take_log`` value.

        Parameters
        ----------
        log: boolean
            Sets whether the transfer function should use log or linear space.
        """
        self.log = log

    def build_transfer_function(self):
        """
        Builds the transfer function according to the current state of the
        TransferFunctionHelper.

        If no bounds have been set, they are computed from the full dataset
        via ``set_bounds``. The new transfer function is stored on
        ``self.tf`` as well as returned.

        Returns
        -------

        A ColorTransferFunction object.

        """
        if self.bounds is None:
            mylog.info(
                "Calculating data bounds. This may take a while. "
                "Set the TransferFunctionHelper.bounds to avoid this."
            )
            self.set_bounds()

        if self.log:
            mi, ma = np.log10(self.bounds[0]), np.log10(self.bounds[1])
        else:
            mi, ma = self.bounds

        self.tf = ColorTransferFunction(
            (mi, ma), grey_opacity=self.grey_opacity, nbins=512
        )
        return self.tf

    def setup_default(self):
        """Setup a default colormap

        Adds 10 gaussian layers whose colors sample the 'nipy_spectral'
        colormap to the current ColorTransferFunction. It then rescales the
        last channel function (the one ``plot`` draws as alpha, presumably)
        so that its mean value is 2, giving a contrast that does not depend
        on the number of bins.

        If no transfer function has been built yet, ``build_transfer_function``
        is called first.
        """
        if self.tf is None:
            self.build_transfer_function()
        self.tf.add_layers(10, colormap="nipy_spectral")
        last = self.tf.funcs[-1]
        # size / sum normalizes the channel to unit mean; the factor 2 then
        # doubles it.
        factor = last.y.size / last.y.sum()
        last.y *= 2 * factor

    def plot(self, fn=None, profile_field=None, profile_weight=None):
        """
        Save the current transfer function to a bitmap, or display
        it inline.

        If no transfer function has been built yet, a default one is created
        (``build_transfer_function`` followed by ``setup_default``).

        Parameters
        ----------
        fn: string, optional
            Filename to save the image to. If None, an IPython ``Image`` is
            returned instead (requires IPython).
        profile_field: string, optional
            If given, a 1D profile of this field, binned by the rendered
            field, is overplotted. It is scaled so that its maximum matches
            the maximum of the alpha channel.
        profile_weight: string, optional
            Weight field used when the profile for the current rendered
            field is first created. Profiles are cached per rendered field,
            so this argument has no effect once such a profile exists.

        Returns
        -------

        If fn is None, an IPython ``Image`` of the plot; otherwise None.

        """
        from matplotlib.figure import Figure

        from yt.visualization._mpl_imports import FigureCanvasAgg

        if self.tf is None:
            self.build_transfer_function()
            self.setup_default()
        tf = self.tf
        if self.log:
            xfunc = np.logspace
            xmi, xma = np.log10(self.bounds[0]), np.log10(self.bounds[1])
        else:
            xfunc = np.linspace
            # Need to strip units off of the bounds to avoid a recursion error
            # in matplotlib 1.3.1
            xmi, xma = (np.float64(b) for b in self.bounds)

        x = xfunc(xmi, xma, tf.nbins)
        y = tf.funcs[3].y
        # Bar widths: spacing between consecutive samples; the last bar
        # reuses the previous spacing.
        w = np.append(x[1:] - x[:-1], x[-1] - x[-2])
        colors = np.array(
            [tf.funcs[0].y, tf.funcs[1].y, tf.funcs[2].y, np.ones_like(x)]
        ).T

        fig = Figure(figsize=[6, 3])
        canvas = FigureCanvasAgg(fig)
        ax = fig.add_axes([0.2, 0.2, 0.75, 0.75])
        ax.bar(
            x,
            y,
            w,
            edgecolor=[0.0, 0.0, 0.0, 0.0],
            log=self.log,
            color=colors,
            bottom=[0],
        )

        if profile_field is not None:
            prof = self.profiles.get(self.field)
            if prof is None:
                self.setup_profile(profile_field, profile_weight)
                prof = self.profiles[self.field]
            try:
                prof[profile_field]
            except KeyError:
                prof.add_fields([profile_field])
            pvals = prof[profile_field]
            # Strip units, if any, for matplotlib 1.3.1
            xplot = np.array(prof.x)
            yplot = np.array(pvals * y.max() / pvals.max())
            # Thick white line underneath a thin black one keeps the profile
            # visible over any bar color.
            ax.plot(xplot, yplot, color="w", linewidth=3)
            ax.plot(xplot, yplot, color="k")

        ax.set_xscale("log" if self.log else "linear")
        ax.set_xlim(x.min(), x.max())
        ax.set_xlabel(self.ds._get_field_info(self.field).get_label())
        ax.set_ylabel(r"$\mathrm{alpha}$")
        ax.set_ylim(y.max() * 1.0e-3, y.max() * 2)

        if fn is None:
            from IPython.display import Image

            buf = BytesIO()
            canvas.print_figure(buf)
            return Image(buf.getvalue())
        else:
            fig.savefig(fn)

    def setup_profile(self, profile_field=None, profile_weight=None):
        """
        Create and cache a 1D profile binned by the rendered field.

        The profile is computed over ``self.ds.all_data()`` with 128 bins.
        It uses ``self.bounds`` as the extrema and ``self.log`` for the
        binning of the rendered field. The result is stored in
        ``self.profiles[self.field]``.

        Parameters
        ----------
        profile_field: string, optional
            Field to profile. Defaults to "cell_volume".
        profile_weight: string, optional
            Passed through to ``create_profile`` as ``weight_field``.
        """
        if profile_field is None:
            profile_field = "cell_volume"
        prof = create_profile(
            self.ds.all_data(),
            self.field,
            profile_field,
            n_bins=128,
            extrema={self.field: self.bounds},
            weight_field=profile_weight,
            logs={self.field: self.log},
        )
        self.profiles[self.field] = prof