1import weakref
2
3from yt.funcs import obj_length
4from yt.units.yt_array import YTQuantity
5from yt.utilities.exceptions import YTDimensionalityError, YTFieldNotParseable
6from yt.visualization.line_plot import LineBuffer
7
8from .data_containers import _get_ipython_key_completion
9
10
11class RegionExpression:
12    _all_data = None
13
14    def __init__(self, ds):
15        self.ds = weakref.proxy(ds)
16
17    @property
18    def all_data(self):
19        if self._all_data is None:
20            self._all_data = self.ds.all_data()
21        return self._all_data
22
23    def __getitem__(self, item):
24        # At first, we will only implement this as accepting a slice that is
25        # (optionally) unitful corresponding to a specific set of coordinates
26        # that result in a rectangular prism or a slice.
27        try:
28            return self.all_data[item]
29        except (TypeError, YTFieldNotParseable):
30            pass
31
32        if isinstance(item, slice):
33            if obj_length(item.start) == 3 and obj_length(item.stop) == 3:
34                # This is for a ray that is not orthogonal to an axis.
35                # it's straightforward to do this, so we create a ray
36                # and drop out here.
37                return self._create_ray(item)
38            else:
39                # This is for the case where we give a slice as an index; one
40                # possible use case of this would be where we supply something
41                # like ds.r[::256j] .  This would be expanded, implicitly into
42                # ds.r[::256j, ::256j, ::256j].  Other cases would be if we do
43                # ds.r[0.1:0.9] where it will be expanded along all dimensions.
44                item = tuple(item for _ in range(self.ds.dimensionality))
45
46        if item is Ellipsis:
47            item = (Ellipsis,)
48
49        # from this point, item is implicitly assumed to be iterable
50        if Ellipsis in item:
51            # expand "..." into the appropriate number of ":"
52            item = list(item)
53            idx = item.index(Ellipsis)
54            item.pop(idx)
55            if Ellipsis in item:
56                # this error mimics numpy's
57                raise IndexError("an index can only have a single ellipsis ('...')")
58            while len(item) < self.ds.dimensionality:
59                item.insert(idx, slice(None))
60
61        if len(item) != self.ds.dimensionality:
62            # Not the right specification, and we don't want to do anything
63            # implicitly.  Note that this happens *after* the implicit expansion
64            # of a single slice.
65            raise YTDimensionalityError(len(item), self.ds.dimensionality)
66
67        # OK, now we need to look at our slices.  How many are a specific
68        # coordinate?
69
70        nslices = sum(isinstance(v, slice) for v in item)
71        if nslices == 0:
72            return self._create_point(item)
73        elif nslices == 1:
74            return self._create_ortho_ray(item)
75        elif nslices == 2:
76            return self._create_slice(item)
77        else:
78            if all(s.start is s.stop is s.step is None for s in item):
79                return self.all_data
80            return self._create_region(item)
81
82    def _ipython_key_completions_(self):
83        return _get_ipython_key_completion(self.ds)
84
85    def _spec_to_value(self, input):
86        if isinstance(input, tuple):
87            v = self.ds.quan(input[0], input[1]).to("code_length")
88        elif isinstance(input, YTQuantity):
89            v = self.ds.quan(input).to("code_length")
90        else:
91            v = self.ds.quan(input, "code_length")
92        return v
93
94    def _create_slice(self, slice_tuple):
95        # This is somewhat more complex because we want to allow for slicing
96        # in one dimension but also *not* using the entire domain; for instance
97        # this means we allow something like ds.r[0.5, 0.1:0.4, 0.1:0.4].
98        axis = None
99        new_slice = []
100        for ax, v in enumerate(slice_tuple):
101            if not isinstance(v, slice):
102                if axis is not None:
103                    raise RuntimeError
104                axis = ax
105                coord = self._spec_to_value(v)
106                new_slice.append(slice(None, None, None))
107            else:
108                new_slice.append(v)
109        # This new slice doesn't need to be a tuple
110        dim = self.ds.dimensionality
111        if dim < 2:
112            raise ValueError(
113                "Can not create a slice from data with dimensionality '%d'" % dim
114            )
115        if dim == 2:
116            coord = self.ds.domain_center[2]
117            axis = 2
118        source = self._create_region(new_slice)
119        sl = self.ds.slice(axis, coord, data_source=source)
120        # Now, there's the possibility that what we're also seeing here
121        # includes some steps, which would be for getting back a fixed
122        # resolution buffer.  We check for that by checking if we have
123        # exactly two imaginary steps.
124        xax = self.ds.coordinates.x_axis[axis]
125        yax = self.ds.coordinates.y_axis[axis]
126        if (
127            getattr(new_slice[xax].step, "imag", 0.0) != 0.0
128            and getattr(new_slice[yax].step, "imag", 0.0) != 0.0
129        ):
130            # We now need to convert to a fixed res buffer.
131            # We'll do this by getting the x/y axes, and then using that.
132            width = source.right_edge[xax] - source.left_edge[xax]
133            height = source.right_edge[yax] - source.left_edge[yax]
134            # Make a resolution tuple with
135            resolution = (int(new_slice[xax].step.imag), int(new_slice[yax].step.imag))
136            sl = sl.to_frb(width=width, resolution=resolution, height=height)
137        return sl
138
139    def _slice_to_edges(self, ax, val):
140        if val.start is None:
141            l = self.ds.domain_left_edge[ax]
142        else:
143            l = self._spec_to_value(val.start)
144        if val.stop is None:
145            r = self.ds.domain_right_edge[ax]
146        else:
147            r = self._spec_to_value(val.stop)
148        if r < l:
149            raise RuntimeError
150        return l, r
151
152    def _create_region(self, bounds_tuple):
153        left_edge = self.ds.domain_left_edge.copy()
154        right_edge = self.ds.domain_right_edge.copy()
155        dims = []
156        for ax, b in enumerate(bounds_tuple):
157            l, r = self._slice_to_edges(ax, b)
158            left_edge[ax] = l
159            right_edge[ax] = r
160            dims.append(getattr(b.step, "imag", None))
161        center = [(cl + cr) / 2.0 for cl, cr in zip(left_edge, right_edge)]
162        if all(d is not None for d in dims):
163            return self.ds.arbitrary_grid(left_edge, right_edge, dims)
164        return self.ds.region(center, left_edge, right_edge)
165
166    def _create_point(self, point_tuple):
167        coord = [self._spec_to_value(p) for p in point_tuple]
168        return self.ds.point(coord)
169
170    def _create_ray(self, ray_slice):
171        start_point = [self._spec_to_value(v) for v in ray_slice.start]
172        end_point = [self._spec_to_value(v) for v in ray_slice.stop]
173        if getattr(ray_slice.step, "imag", 0.0) != 0.0:
174            return LineBuffer(self.ds, start_point, end_point, int(ray_slice.step.imag))
175        else:
176            return self.ds.ray(start_point, end_point)
177
178    def _create_ortho_ray(self, ray_tuple):
179        axis = None
180        new_slice = []
181        coord = []
182        npoints = 0
183        start_point = []
184        end_point = []
185        for ax, v in enumerate(ray_tuple):
186            if not isinstance(v, slice):
187                val = self._spec_to_value(v)
188                coord.append(val)
189                new_slice.append(slice(None, None, None))
190                start_point.append(val)
191                end_point.append(val)
192            else:
193                if axis is not None:
194                    raise RuntimeError
195                if getattr(v.step, "imag", 0.0) != 0.0:
196                    npoints = int(v.step.imag)
197                    xi = self._spec_to_value(v.start)
198                    xf = self._spec_to_value(v.stop)
199                    dx = (xf - xi) / npoints
200                    start_point.append(xi + 0.5 * dx)
201                    end_point.append(xf - 0.5 * dx)
202                else:
203                    axis = ax
204                    new_slice.append(v)
205        if npoints > 0:
206            ray = LineBuffer(self.ds, start_point, end_point, npoints)
207        else:
208            if axis == 1:
209                coord = [coord[1], coord[0]]
210            source = self._create_region(new_slice)
211            ray = self.ds.ortho_ray(axis, coord, data_source=source)
212        return ray
213