import weakref

from yt.funcs import obj_length
from yt.units.yt_array import YTQuantity
from yt.utilities.exceptions import YTDimensionalityError, YTFieldNotParseable
from yt.visualization.line_plot import LineBuffer

from .data_containers import _get_ipython_key_completion


class RegionExpression:
    """Translate index expressions on a dataset into data objects.

    ``__getitem__`` first tries the index as a key into the dataset's
    ``all_data()`` object.  If that fails, the index is interpreted as a set
    of per-axis coordinates and slices, and the number of slices selects the
    object that is built: a point (0), an axis-aligned ray (1), a slice (2)
    or a region (all axes sliced).  A slice step with a non-zero imaginary
    part (``64j``) is read as a number of samples along that axis.
    """

    # Lazily populated cache for ``ds.all_data()``; see the ``all_data`` property.
    _all_data = None

    def __init__(self, ds):
        # Hold the dataset through a weak proxy so this helper does not keep
        # a strong reference to it.
        self.ds = weakref.proxy(ds)

    @property
    def all_data(self):
        """Return ``ds.all_data()``, creating it on first access."""
        if self._all_data is None:
            self._all_data = self.ds.all_data()
        return self._all_data

    def __getitem__(self, item):
        """Build the data object described by the index expression ``item``.

        Raises
        ------
        IndexError
            If ``item`` contains more than one ``Ellipsis``.
        YTDimensionalityError
            If, after expansion, the number of index entries differs from
            ``ds.dimensionality``.
        """
        # At first, we will only implement this as accepting a slice that is
        # (optionally) unitful corresponding to a specific set of coordinates
        # that result in a rectangular prism or a slice.
        try:
            return self.all_data[item]
        except (TypeError, YTFieldNotParseable):
            pass

        if isinstance(item, slice):
            if obj_length(item.start) == 3 and obj_length(item.stop) == 3:
                # This is for a ray that is not orthogonal to an axis.
                # it's straightforward to do this, so we create a ray
                # and drop out here.
                return self._create_ray(item)
            else:
                # This is for the case where we give a slice as an index; one
                # possible use case of this would be where we supply something
                # like ds.r[::256j] .  This would be expanded, implicitly into
                # ds.r[::256j, ::256j, ::256j].  Other cases would be if we do
                # ds.r[0.1:0.9] where it will be expanded along all dimensions.
                item = tuple(item for _ in range(self.ds.dimensionality))

        if item is Ellipsis:
            item = (Ellipsis,)

        # from this point, item is implicitly assumed to be iterable
        # Locate "..." by identity: the entries may be unitful or array-like
        # objects for which an ``==`` comparison (as done by ``in`` and
        # ``list.index``) is not a plain boolean test.
        ellipsis_at = [i for i, v in enumerate(item) if v is Ellipsis]
        if ellipsis_at:
            if len(ellipsis_at) > 1:
                # this error mimics numpy's
                raise IndexError("an index can only have a single ellipsis ('...')")
            # expand "..." into the appropriate number of ":"
            idx = ellipsis_at[0]
            item = list(item)
            item.pop(idx)
            while len(item) < self.ds.dimensionality:
                item.insert(idx, slice(None))

        if len(item) != self.ds.dimensionality:
            # Not the right specification, and we don't want to do anything
            # implicitly.  Note that this happens *after* the implicit expansion
            # of a single slice.
            raise YTDimensionalityError(len(item), self.ds.dimensionality)

        # OK, now we need to look at our slices.  How many are a specific
        # coordinate?

        nslices = sum(isinstance(v, slice) for v in item)
        if nslices == 0:
            return self._create_point(item)
        elif nslices == 1:
            return self._create_ortho_ray(item)
        elif nslices == 2:
            return self._create_slice(item)
        else:
            if all(s.start is s.stop is s.step is None for s in item):
                return self.all_data
            return self._create_region(item)

    def _ipython_key_completions_(self):
        """Delegate IPython key completion to the shared dataset helper."""
        return _get_ipython_key_completion(self.ds)

    def _spec_to_value(self, input):
        """Convert a coordinate specification to a quantity in ``code_length``.

        ``input`` may be a ``(value, unit)`` tuple, a ``YTQuantity``, or
        anything else, which is passed to ``ds.quan`` with ``code_length``
        units.
        """
        if isinstance(input, tuple):
            v = self.ds.quan(input[0], input[1]).to("code_length")
        elif isinstance(input, YTQuantity):
            v = self.ds.quan(input).to("code_length")
        else:
            v = self.ds.quan(input, "code_length")
        return v

    def _create_slice(self, slice_tuple):
        """Build a slice, or a fixed resolution buffer of it.

        ``slice_tuple`` holds one entry per axis.  The single non-slice entry
        gives the slicing axis and coordinate; the slice entries bound the
        region used as the data source.  When both in-plane slices carry an
        imaginary step, the slice is converted with ``to_frb``.
        """
        # This is somewhat more complex because we want to allow for slicing
        # in one dimension but also *not* using the entire domain; for instance
        # this means we allow something like ds.r[0.5, 0.1:0.4, 0.1:0.4].
        axis = None
        new_slice = []
        for ax, v in enumerate(slice_tuple):
            if not isinstance(v, slice):
                if axis is not None:
                    raise RuntimeError(
                        "A slice requires exactly one fixed coordinate, "
                        "but more than one was given."
                    )
                axis = ax
                coord = self._spec_to_value(v)
                new_slice.append(slice(None, None, None))
            else:
                new_slice.append(v)
        # This new slice doesn't need to be a tuple
        dim = self.ds.dimensionality
        if dim < 2:
            raise ValueError(
                "Can not create a slice from data with dimensionality '%d'" % dim
            )
        if dim == 2:
            # Both entries are slices here, so the slicing axis and
            # coordinate are set explicitly.
            coord = self.ds.domain_center[2]
            axis = 2
        source = self._create_region(new_slice)
        sl = self.ds.slice(axis, coord, data_source=source)
        # Now, there's the possibility that what we're also seeing here
        # includes some steps, which would be for getting back a fixed
        # resolution buffer.  We check for that by checking if we have
        # exactly two imaginary steps.
        xax = self.ds.coordinates.x_axis[axis]
        yax = self.ds.coordinates.y_axis[axis]
        if (
            getattr(new_slice[xax].step, "imag", 0.0) != 0.0
            and getattr(new_slice[yax].step, "imag", 0.0) != 0.0
        ):
            # We now need to convert to a fixed res buffer.
            # We'll do this by getting the x/y axes, and then using that.
            width = source.right_edge[xax] - source.left_edge[xax]
            height = source.right_edge[yax] - source.left_edge[yax]
            # Make a resolution tuple with the sample counts along x and y.
            resolution = (int(new_slice[xax].step.imag), int(new_slice[yax].step.imag))
            sl = sl.to_frb(width=width, resolution=resolution, height=height)
        return sl

    def _slice_to_edges(self, ax, val):
        """Return ``(left, right)`` bounds along axis ``ax`` for slice ``val``.

        An open bound (``None``) falls back to the domain edge on that axis.

        Raises
        ------
        RuntimeError
            If the right bound is smaller than the left bound.
        """
        if val.start is None:
            left = self.ds.domain_left_edge[ax]
        else:
            left = self._spec_to_value(val.start)
        if val.stop is None:
            right = self.ds.domain_right_edge[ax]
        else:
            right = self._spec_to_value(val.stop)
        if right < left:
            raise RuntimeError(
                f"Right edge ({right}) is smaller than left edge ({left}) "
                f"along axis {ax}."
            )
        return left, right

    def _create_region(self, bounds_tuple):
        """Build a region, or an arbitrary grid if every axis has a resolution.

        ``bounds_tuple`` is a sequence of slices, one per axis.  Axes without
        a corresponding entry keep the full domain extent.
        """
        left_edge = self.ds.domain_left_edge.copy()
        right_edge = self.ds.domain_right_edge.copy()
        dims = []
        for ax, b in enumerate(bounds_tuple):
            left, right = self._slice_to_edges(ax, b)
            left_edge[ax] = left
            right_edge[ax] = right
            # Only a non-zero imaginary step is a resolution.  Real numbers
            # also expose ``.imag`` (equal to 0), so a plain attribute lookup
            # would mistake a real step for a resolution of zero.
            npix = getattr(b.step, "imag", 0.0)
            dims.append(int(npix) if npix != 0.0 else None)
        center = [(cl + cr) / 2.0 for cl, cr in zip(left_edge, right_edge)]
        if all(d is not None for d in dims):
            return self.ds.arbitrary_grid(left_edge, right_edge, dims)
        return self.ds.region(center, left_edge, right_edge)

    def _create_point(self, point_tuple):
        """Build a point object from one coordinate specification per axis."""
        coord = [self._spec_to_value(p) for p in point_tuple]
        return self.ds.point(coord)

    def _create_ray(self, ray_slice):
        """Build a ray between ``ray_slice.start`` and ``ray_slice.stop``.

        Returns a ``LineBuffer`` when the step has a non-zero imaginary part
        (used as the number of samples), otherwise ``ds.ray``.
        """
        start_point = [self._spec_to_value(v) for v in ray_slice.start]
        end_point = [self._spec_to_value(v) for v in ray_slice.stop]
        if getattr(ray_slice.step, "imag", 0.0) != 0.0:
            return LineBuffer(self.ds, start_point, end_point, int(ray_slice.step.imag))
        else:
            return self.ds.ray(start_point, end_point)

    def _create_ortho_ray(self, ray_tuple):
        """Build an axis-aligned ray from an index with exactly one slice.

        The non-slice entries fix the ray's position.  If the slice has an
        imaginary step, a ``LineBuffer`` with that many samples is returned,
        with its end points moved inward by half a sample spacing; otherwise
        ``ds.ortho_ray`` is used with the bounded region as data source.
        """
        axis = None
        new_slice = []
        coord = []
        npoints = 0
        start_point = []
        end_point = []
        for ax, v in enumerate(ray_tuple):
            if not isinstance(v, slice):
                val = self._spec_to_value(v)
                coord.append(val)
                new_slice.append(slice(None, None, None))
                start_point.append(val)
                end_point.append(val)
            else:
                if axis is not None:
                    raise RuntimeError(
                        "An axis-aligned ray requires exactly one slice, "
                        "but more than one was given."
                    )
                if getattr(v.step, "imag", 0.0) != 0.0:
                    npoints = int(v.step.imag)
                    # Open bounds default to the domain edges, mirroring
                    # _slice_to_edges; None cannot be turned into a quantity.
                    if v.start is None:
                        xi = self.ds.domain_left_edge[ax]
                    else:
                        xi = self._spec_to_value(v.start)
                    if v.stop is None:
                        xf = self.ds.domain_right_edge[ax]
                    else:
                        xf = self._spec_to_value(v.stop)
                    dx = (xf - xi) / npoints
                    start_point.append(xi + 0.5 * dx)
                    end_point.append(xf - 0.5 * dx)
                else:
                    axis = ax
                    new_slice.append(v)
        if npoints > 0:
            ray = LineBuffer(self.ds, start_point, end_point, npoints)
        else:
            if axis == 1:
                # NOTE(review): the two fixed coordinates are swapped for
                # axis 1, presumably to match the (x_axis, y_axis) ordering
                # ortho_ray expects for that axis -- confirm.
                coord = [coord[1], coord[0]]
            source = self._create_region(new_slice)
            ray = self.ds.ortho_ray(axis, coord, data_source=source)
        return ray