1import warnings
2import weakref
3from typing import List, Tuple
4
5import numpy as np
6
7import yt.geometry.particle_deposit as particle_deposit
8from yt.config import ytcfg
9from yt.data_objects.selection_objects.data_selection_objects import (
10    YTSelectionContainer,
11)
12from yt.funcs import is_sequence
13from yt.geometry.selection_routines import convert_mask_to_indices
14from yt.units.yt_array import YTArray
15from yt.utilities.exceptions import (
16    YTFieldTypeNotFound,
17    YTParticleDepositionNotImplemented,
18)
19from yt.utilities.lib.interpolators import ghost_zone_interpolate
20from yt.utilities.lib.mesh_utilities import clamp_edges
21from yt.utilities.nodal_data_utils import get_nodal_slices
22
# Module-level switch, read once at import time from the "reconstruct_index"
# option of the "yt" configuration section and coerced to bool. When true,
# AMRGridPatch._prepare_grid clamps each grid's edges to integer multiples of
# its parent's cell width; that step can be expensive, hence the opt-out.
RECONSTRUCT_INDEX = bool(ytcfg.get("yt", "reconstruct_index"))
24
25
class AMRGridPatch(YTSelectionContainer):
    """A single rectangular patch of cells belonging to an AMR grid index.

    The patch geometry (``LeftEdge``, ``RightEdge``, ``ActiveDimensions``) is
    copied from the index in ``_prepare_grid`` and the cell width ``dds`` is
    computed in ``_setup_dx``.  Hierarchy attributes such as ``Parent``,
    ``Children`` and ``Level`` are read by this class but assigned elsewhere
    (presumably by the frontend's index) -- NOTE(review): confirm per frontend.
    """

    _spatial = True
    _num_ghost_zones = 0
    _grids = None
    # Grid ids differ from their row in the index arrays (grid_left_edge,
    # grid_dimensions, ...) by this offset: row = id - _id_offset.
    _id_offset = 1
    # When true, the most recent selector mask is kept and reused for a
    # selector with the same hash (see _get_selector_mask).
    _cache_mask = True

    _type_name = "grid"
    _skip_add = True
    _con_args = ("id", "filename")
    _container_fields = (
        ("index", "dx"),
        ("index", "dy"),
        ("index", "dz"),
        ("index", "x"),
        ("index", "y"),
        ("index", "z"),
    )
    # Same-level grids overlapping this one; when set, their footprint is
    # masked out of child_mask / child_index_mask just like a child grid's.
    OverlappingSiblings = None

    def __init__(self, id, filename=None, index=None):
        """Create a grid patch with identifier *id* belonging to *index*.

        *index* must be provided (its ``dataset`` attribute is read
        immediately); the default of ``None`` is kept only for signature
        compatibility.
        """
        super().__init__(index.dataset, None)
        self.id = id
        self._child_mask = self._child_indices = self._child_index_mask = None
        self.ds = index.dataset
        # Weak proxy so the grid does not keep its index alive.
        self._index = weakref.proxy(index)
        self.start_index = None
        self.filename = filename
        # State for the single-entry selector-mask cache (_get_selector_mask).
        self._last_mask = None
        self._last_count = -1
        self._last_selector_id = None

    def get_global_startindex(self):
        """
        Return the integer starting index for each dimension at the current
        level.

        For a root grid (no ``Parent``) this is the left edge's offset from
        the domain left edge in units of ``dds``.  For a child grid it is the
        parent's start index plus this grid's offset in parent cells, scaled
        by ``refine_by``; that result is cached in ``self.start_index``.
        """
        if self.start_index is not None:
            return self.start_index
        if self.Parent is None:
            left = self.LeftEdge.d - self.ds.domain_left_edge.d
            start_index = left / self.dds.d
            return np.rint(start_index).astype("int64").ravel()

        pdx = self.Parent.dds.d
        di = np.rint((self.LeftEdge.d - self.Parent.LeftEdge.d) / pdx)
        start_index = self.Parent.get_global_startindex() + di
        self.start_index = (start_index * self.ds.refine_by).astype("int64").ravel()
        return self.start_index

    def __getitem__(self, key):
        """Return field data for *key*, reshaped to the grid for mesh fields.

        Particle fields, and keys whose field type cannot be determined, are
        returned as produced by the parent class.  Mesh fields are reshaped
        to ``ActiveDimensions``.  Nodal fields get one extra trailing axis of
        length ``2 ** sum(nodal_flag)``.
        """
        tr = super().__getitem__(key)
        try:
            fields = self._determine_fields(key)
        except YTFieldTypeNotFound:
            return tr
        finfo = self.ds._get_field_info(*fields[0])
        if not finfo.sampling_type == "particle":
            num_nodes = 2 ** sum(finfo.nodal_flag)
            new_shape = list(self.ActiveDimensions)
            if num_nodes > 1:
                new_shape += [num_nodes]
            return tr.reshape(new_shape)
        return tr

    def convert(self, datatype):
        """
        Return ``self.ds[datatype]``.

        Legacy helper: historically used to obtain a code-to-cgs conversion
        factor.  Any error raised by the dataset lookup propagates.
        """
        return self.ds[datatype]

    @property
    def shape(self):
        """The grid's ``ActiveDimensions``."""
        return self.ActiveDimensions

    def _reshape_vals(self, arr):
        """Reshape a flat *arr* to ``ActiveDimensions`` (C order).

        An array that is already 3-D is returned untouched.
        """
        if len(arr.shape) == 3:
            return arr
        return arr.reshape(self.ActiveDimensions, order="C")

    def _generate_container_field(self, field):
        """Return an ``("index", ...)`` cell-width or cell-center field.

        Raises
        ------
        KeyError
            If *field* is not one of the supported container fields.
        """
        if self._current_chunk is None:
            self.index._identify_base_chunk(self)
        # field -> (chunk attribute holding the per-cell data, axis)
        sources = {
            ("index", "dx"): ("fwidth", 0),
            ("index", "dy"): ("fwidth", 1),
            ("index", "dz"): ("fwidth", 2),
            ("index", "x"): ("fcoords", 0),
            ("index", "y"): ("fcoords", 1),
            ("index", "z"): ("fcoords", 2),
        }
        if field not in sources:
            # Previously this fell through with an unbound local; fail clearly.
            raise KeyError(field)
        attr, axis = sources[field]
        tr = getattr(self._current_chunk, attr)[:, axis]
        return self._reshape_vals(tr)

    def _setup_dx(self):
        """Compute ``self.dds``, the cell width along each axis.

        Child grids derive it from the parent's ``dds`` divided by
        ``refine_by``.  Root grids use their own extent divided by
        ``ActiveDimensions``.  Axes beyond the dataset's dimensionality are
        given the full domain width.  The result is a ``YTArray`` carrying
        the units of the index's ``grid_left_edge``.
        """
        # We don't assume that dx=dy=dz here (we probably do elsewhere).
        grid_row = self.id - self._id_offset
        ds = self.ds
        index = self.index
        if self.Parent is not None:
            if not hasattr(self.Parent, "dds"):
                self.Parent._setup_dx()
            self.dds = self.Parent.dds.d / self.ds.refine_by
        else:
            LE, RE = (
                index.grid_left_edge[grid_row, :].d,
                index.grid_right_edge[grid_row, :].d,
            )
            self.dds = (RE - LE) / self.ActiveDimensions
        # Independent checks: a 1-D dataset must reset BOTH the y and z widths.
        # (The former ``if < 3 ... elif < 2`` made the y branch unreachable.)
        if self.ds.dimensionality < 2:
            self.dds[1] = ds.domain_right_edge[1] - ds.domain_left_edge[1]
        if self.ds.dimensionality < 3:
            self.dds[2] = ds.domain_right_edge[2] - ds.domain_left_edge[2]
        self.dds = self.dds.view(YTArray)
        self.dds.units = self.index.grid_left_edge.units

    def __repr__(self):
        return "AMRGridPatch_%04i" % (self.id)

    def __int__(self):
        return self.id

    def clear_data(self):
        """
        Clear cached data via the parent class's ``clear_data``, then
        recompute ``dds``.
        """
        super().clear_data()
        self._setup_dx()

    def _prepare_grid(self):
        """Copies all the appropriate attributes from the index."""
        # This is definitely the slowest part of generating the index.
        # Attribute names follow Enzo's conventions rather than PEP 8.
        h = self.index  # cache it
        my_ind = self.id - self._id_offset
        self.ActiveDimensions = h.grid_dimensions[my_ind]
        self.LeftEdge = h.grid_left_edge[my_ind]
        self.RightEdge = h.grid_right_edge[my_ind]
        # This can be expensive so we allow people to disable this behavior
        # via a config option
        if RECONSTRUCT_INDEX:
            # Parent may be a single grid, a (possibly empty) sequence of
            # grids, or None; use the first parent when given a sequence.
            if is_sequence(self.Parent) and len(self.Parent) > 0:
                p = self.Parent[0]
            else:
                p = self.Parent
            if p is not None and p != []:
                # clamp grid edges to an integer multiple of the parent cell
                # width
                clamp_edges(self.LeftEdge, p.LeftEdge, p.dds)
                clamp_edges(self.RightEdge, p.RightEdge, p.dds)
        h.grid_levels[my_ind, 0] = self.Level
        # This might be needed for streaming formats
        # self.Time = h.gridTimes[my_ind,0]
        self.NumberOfParticles = h.grid_particle_count[my_ind, 0]

    def get_position(self, index):
        """Returns center position of an *index*."""
        pos = (index + 0.5) * self.dds + self.LeftEdge
        return pos

    def _fill_child_mask(self, child, mask, tofill, dlevel=1):
        """Write *tofill* into the region of *mask* covered by *child*.

        *dlevel* is the level difference between *child* and this grid (1 for
        children, 0 for overlapping siblings); the refinement factor used is
        ``refine_by ** dlevel``.  A region that would be zero cells wide
        along an axis is widened to one cell.
        """
        rf = self.ds.refine_by
        if dlevel != 1:
            rf = rf**dlevel
        gi, cgi = self.get_global_startindex(), child.get_global_startindex()
        start = np.maximum(0, cgi // rf - gi)
        end = np.minimum((cgi + child.ActiveDimensions) // rf - gi, self.ActiveDimensions)
        end += start == end
        mask[
            start[0] : end[0],
            start[1] : end[1],
            start[2] : end[2],
        ] = tofill

    @property
    def child_mask(self):
        """
        Return a new boolean array shaped like ``ActiveDimensions`` that is
        False where a child grid (or an overlapping sibling) covers this
        grid, i.e. where higher-resolution data is available, and True
        elsewhere.  It is recomputed on every access.
        """
        child_mask = np.ones(self.ActiveDimensions, "bool")
        for child in self.Children:
            self._fill_child_mask(child, child_mask, 0)
        for sibling in self.OverlappingSiblings or []:
            self._fill_child_mask(sibling, child_mask, 0, dlevel=0)
        return child_mask

    @property
    def child_indices(self):
        """Boolean array that is True where ``child_mask`` is False."""
        return self.child_mask == 0

    @property
    def child_index_mask(self):
        """
        Return a new int32 array shaped like ``ActiveDimensions`` that is -1
        where no child (or overlapping sibling) grid is present and holds that
        grid's ``id`` otherwise.  It is recomputed on every access.
        """
        child_index_mask = np.full(self.ActiveDimensions, -1, dtype="int32")
        for child in self.Children:
            self._fill_child_mask(child, child_index_mask, child.id)
        for sibling in self.OverlappingSiblings or []:
            self._fill_child_mask(sibling, child_index_mask, sibling.id, dlevel=0)
        return child_index_mask

    def retrieve_ghost_zones(self, n_zones, fields, all_levels=False, smoothed=False):
        """Return a covering grid padded by *n_zones* cells on every side.

        The cube has dimensions ``ActiveDimensions + 2 * n_zones``.  It is
        built at this grid's level, or at ``index.max_level + 1`` when
        *all_levels* is true.  It is a ``smoothed_covering_grid`` when
        *smoothed* is true and a ``covering_grid`` otherwise.  It inherits a
        copy of this grid's field parameters, and ``cube._base_grid`` is set
        to this grid.
        """
        # Shift the left edge outward by n_zones cells of this grid's width.
        nl = self.get_global_startindex() - n_zones
        new_left_edge = nl * self.dds + self.ds.domain_left_edge

        level = self.Level
        if all_levels:
            level = self.index.max_level + 1
        kwargs = {
            "dims": self.ActiveDimensions + 2 * n_zones,
            "num_ghost_zones": n_zones,
            "use_pbar": False,
            "fields": fields,
        }
        # Give the cube a copy of this grid's field parameters.
        field_parameters = dict(self.field_parameters)
        if smoothed:
            cube = self.ds.smoothed_covering_grid(
                level, new_left_edge, field_parameters=field_parameters, **kwargs
            )
        else:
            cube = self.ds.covering_grid(
                level, new_left_edge, field_parameters=field_parameters, **kwargs
            )
        cube._base_grid = self
        return cube

    def get_vertex_centered_data(
        self,
        fields: List[Tuple[str, str]],
        smoothed: bool = True,
        no_ghost: bool = False,
    ):
        """Return vertex-centred versions of cell-centred *fields*.

        Each output array has shape ``ActiveDimensions + 1`` and carries the
        field's ``output_units``.  With ``no_ghost=False`` every vertex is the
        mean of the 8 surrounding cells, taken from a one-zone ghost covering
        grid (``smoothed`` is forwarded to ``retrieve_ghost_zones``).  With
        ``no_ghost=True`` the grid's own data is interpolated with
        ``ghost_zone_interpolate``, which extrapolates at the boundary.

        Returns a dict mapping each field to its array.  Passing a single
        field (a str or tuple) is deprecated; it emits a ``DeprecationWarning``
        and returns the bare array.
        """
        _old_api = isinstance(fields, (str, tuple))
        if _old_api:
            message = (
                "get_vertex_centered_data() requires list of fields, rather than "
                "a single field as an argument."
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            fields = [fields]

        # Drop duplicate entries while keeping the caller's order, so results
        # (and the ghost-zone cube's field list) are deterministic.
        fields = list(dict.fromkeys(fields))
        new_fields = {}
        for field in fields:
            finfo = self.ds._get_field_info(field)
            new_fields[field] = self.ds.arr(
                np.zeros(self.ActiveDimensions + 1), finfo.output_units
            )
        if no_ghost:
            for field in fields:
                # Convert to the output array's units (as the ghost-zone path
                # does) so values match the units they are labelled with, then
                # ensure native-endian float64 for the interpolator.
                src = self[field].in_units(new_fields[field].units)
                old_field = np.asarray(src, dtype="=f8")
                # Cell centres sit half a cell in from the edge; vertices at 0.
                input_left = np.array([0.5, 0.5, 0.5], dtype="float64")
                output_left = np.array([0.0, 0.0, 0.0], dtype="float64")
                # rf = 1 here
                ghost_zone_interpolate(
                    1, old_field, input_left, new_fields[field], output_left
                )
        else:
            cg = self.retrieve_ghost_zones(1, fields, smoothed=smoothed)
            for field in fields:
                src = cg[field].in_units(new_fields[field].units).d
                # .d is a view, so the in-place ops below fill new_fields.
                dest = new_fields[field].d
                np.add(dest, src[1:, 1:, 1:], dest)
                np.add(dest, src[:-1, 1:, 1:], dest)
                np.add(dest, src[1:, :-1, 1:], dest)
                np.add(dest, src[1:, 1:, :-1], dest)
                np.add(dest, src[:-1, 1:, :-1], dest)
                np.add(dest, src[1:, :-1, :-1], dest)
                np.add(dest, src[:-1, :-1, 1:], dest)
                np.add(dest, src[:-1, :-1, :-1], dest)
                np.multiply(dest, 0.125, dest)

        if _old_api:
            return new_fields[fields[0]]
        return new_fields

    def select_icoords(self, dobj):
        """Global integer indices (N, 3) of the cells selected by *dobj*."""
        mask = self._get_selector_mask(dobj.selector)
        if mask is None:
            return np.empty((0, 3), dtype="int64")
        coords = convert_mask_to_indices(mask, self._last_count)
        coords += self.get_global_startindex()[None, :]
        return coords

    def select_fcoords(self, dobj):
        """Cell-centre positions (N, 3) of the cells selected by *dobj*."""
        mask = self._get_selector_mask(dobj.selector)
        if mask is None:
            return np.empty((0, 3), dtype="float64")
        coords = convert_mask_to_indices(mask, self._last_count).astype("float64")
        coords += 0.5
        coords *= self.dds[None, :]
        coords += self.LeftEdge[None, :]
        return coords

    def select_fwidth(self, dobj):
        """Cell widths (N, 3) of the cells selected by *dobj* (all equal dds)."""
        count = self.count(dobj.selector)
        if count == 0:
            return np.empty((0, 3), dtype="float64")
        coords = np.empty((count, 3), dtype="float64")
        for axis in range(3):
            coords[:, axis] = self.dds[axis]
        return coords

    def select_ires(self, dobj):
        """This grid's ``Level``, repeated once per selected cell."""
        mask = self._get_selector_mask(dobj.selector)
        if mask is None:
            return np.empty(0, dtype="int64")
        coords = np.empty(self._last_count, dtype="int64")
        coords[:] = self.Level
        return coords

    def select_tcoords(self, dobj):
        """Return ``dobj.selector.get_dt(self)`` as a ``(dt, t)`` pair."""
        dt, t = dobj.selector.get_dt(self)
        return dt, t

    def smooth(self, *args, **kwargs):
        raise NotImplementedError

    def particle_operation(self, *args, **kwargs):
        raise NotImplementedError

    def deposit(self, positions, fields=None, method=None, kernel_name="cubic"):
        """Deposit particle *fields* at *positions* onto this grid.

        *method* selects the ``particle_deposit.deposit_<method>`` class.
        Returns None if the operation's ``finalize`` returns None; otherwise
        the deposited values, transposed back from the Fortran-ordered
        layout, with the dummy single-grid axis removed.

        Raises
        ------
        YTParticleDepositionNotImplemented
            If no ``deposit_<method>`` class exists.
        """
        cls = getattr(particle_deposit, f"deposit_{method}", None)
        if cls is None:
            raise YTParticleDepositionNotImplemented(method)
        # We allocate number of zones, not number of octs. Everything
        # inside this is Fortran ordered because of the ordering in the
        # octree deposit routines, so we reverse it here to match the
        # convention there
        nvals = tuple(self.ActiveDimensions[::-1])
        # append a dummy dimension because we are only depositing onto
        # one grid
        op = cls(nvals + (1,), kernel_name)
        op.initialize()
        op.process_grid(self, positions, fields)
        vals = op.finalize()
        if vals is None:
            return
        # Fortran-ordered, so transpose.
        vals = vals.transpose()
        # squeeze dummy dimension we appended above
        return np.squeeze(vals, axis=0)

    def select_blocks(self, selector):
        """Yield a single ``(self, mask)`` pair for *selector*."""
        mask = self._get_selector_mask(selector)
        yield self, mask

    def _get_selector_mask(self, selector):
        """Return *selector*'s cell mask for this grid (may be None).

        Also records the selected-cell count in ``self._last_count`` (0 when
        the mask is None).  When ``_cache_mask`` is true, the mask is reused
        for a selector whose hash equals the previous one.
        """
        if self._cache_mask and hash(selector) == self._last_selector_id:
            mask = self._last_mask
        else:
            mask = selector.fill_mask(self)
            if self._cache_mask:
                self._last_mask = mask
            self._last_selector_id = hash(selector)
            if mask is None:
                self._last_count = 0
            else:
                self._last_count = mask.sum()
        return mask

    def select(self, selector, source, dest, offset):
        """Copy *source* values selected by *selector* into *dest*.

        Values are written starting at ``dest[offset]``.  For nodal data
        (*source* larger than ``ActiveDimensions``), each nodal slice goes
        into its own column of *dest*.  Returns the number of selected
        cells.
        """
        mask = self._get_selector_mask(selector)
        count = self.count(selector)
        if count == 0:
            return 0
        dim = np.squeeze(self.ds.dimensionality)
        nodal_flag = source.shape[:dim] - self.ActiveDimensions[:dim]
        if sum(nodal_flag) == 0:
            dest[offset : offset + count] = source[mask]
        else:
            slices = get_nodal_slices(source.shape, nodal_flag, dim)
            for i, sl in enumerate(slices):
                dest[offset : offset + count, i] = source[tuple(sl)][np.squeeze(mask)]
        return count

    def count(self, selector):
        """Number of cells of this grid selected by *selector*."""
        mask = self._get_selector_mask(selector)
        if mask is None:
            return 0
        return self._last_count

    def count_particles(self, selector, x, y, z):
        """Count the points (x, y, z) selected by *selector* (not cached)."""
        # We don't cache the selector results
        count = selector.count_points(x, y, z, 0.0)
        return count

    def select_particles(self, selector, x, y, z):
        """Return *selector*'s selection result for the points (x, y, z)."""
        mask = selector.select_points(x, y, z, 0.0)
        return mask
442