1import os
2import weakref
3
4import numpy as np
5
6from yt.geometry.geometry_handler import Index, YTDataChunk
7from yt.utilities.lib.mesh_utilities import smallest_fwidth
8from yt.utilities.logger import ytLogger as mylog
9
10
class UnstructuredIndex(Index):
    """The Index subclass for unstructured and hexahedral mesh datasets.

    Subclasses must implement ``_initialize_mesh``. The geometry and chunking
    helpers below all read ``self.meshes``, so an implementation is expected
    to populate that collection.
    """

    # Data-object types that this index does not support.
    _unsupported_objects = ("proj", "covering_grid", "smoothed_covering_grid")

    def __init__(self, ds, dataset_type):
        """Record basic dataset information, then defer to ``Index.__init__``.

        Parameters
        ----------
        ds : Dataset
            The dataset being indexed. It is stored as ``self.dataset`` through
            a ``weakref.proxy`` (presumably to avoid a strong reference cycle
            between the dataset and its index -- confirm against Index).
        dataset_type
            The dataset type; stored on the index and forwarded to
            ``Index.__init__``.
        """
        self.dataset_type = dataset_type
        self.dataset = weakref.proxy(ds)
        self.index_filename = self.dataset.parameter_filename
        self.directory = os.path.dirname(self.index_filename)
        self.float_type = np.float64
        super().__init__(ds, dataset_type)

    def _setup_geometry(self):
        """Build the mesh geometry by delegating to ``_initialize_mesh``."""
        mylog.debug("Initializing Unstructured Mesh Geometry Handler.")
        self._initialize_mesh()

    def get_smallest_dx(self):
        """
        Returns (in code units) the smallest cell size in the simulation.

        The value is the minimum of ``smallest_fwidth`` over every mesh in
        ``self.meshes``.

        Raises
        ------
        ValueError
            If ``self.meshes`` is empty (``min`` of an empty sequence).
        """
        return min(
            smallest_fwidth(
                mesh.connectivity_coords, mesh.connectivity_indices, mesh._index_offset
            )
            for mesh in self.meshes
        )

    def convert(self, unit):
        """Return the dataset's conversion factor for ``unit``.

        Raises
        ------
        KeyError
            If ``unit`` is not present in ``self.dataset.conversion_factors``.
        """
        return self.dataset.conversion_factors[unit]

    def _initialize_mesh(self):
        """Construct the mesh objects for this dataset; subclasses must override."""
        raise NotImplementedError

    def _identify_base_chunk(self, dobj):
        """Attach default chunking state to ``dobj``.

        If ``dobj`` has no ``_chunk_info`` yet, it is set to all meshes. If it
        has no ``size`` yet, the size is set to the number of selected elements.
        ``dobj._current_chunk`` is then set to the single "all" chunk.
        """
        if getattr(dobj, "_chunk_info", None) is None:
            dobj._chunk_info = self.meshes
        if getattr(dobj, "size", None) is None:
            dobj.size = self._count_selection(dobj)
        # ``_chunk_all`` yields exactly one chunk; take it without building a list.
        dobj._current_chunk = next(self._chunk_all(dobj))

    def _count_selection(self, dobj, meshes=None):
        """Return how many elements of ``meshes`` are selected by ``dobj.selector``.

        Parameters
        ----------
        dobj
            The data object whose ``selector`` is applied.
        meshes : iterable, optional
            Meshes to count over; defaults to ``dobj._chunk_info``.
        """
        if meshes is None:
            meshes = dobj._chunk_info
        return sum(m.count(dobj.selector) for m in meshes)

    def _chunk_all(self, dobj, cache=True):
        """Yield a single "all" chunk covering every object in the current chunk.

        The objects come from ``dobj._current_chunk.objs`` when that attribute
        exists, and otherwise from ``dobj._chunk_info``. The chunk size is
        ``dobj.size``.
        """
        oobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        yield YTDataChunk(dobj, "all", oobjs, dobj.size, cache)

    def _chunk_spatial(self, dobj, ngz, sort=None, preload_fields=None):
        """Yield one "spatial" chunk per mesh that has selected elements.

        Parameters
        ----------
        dobj
            The data object being chunked.
        ngz : int
            Number of ghost zones. When positive, each chunk holds the mesh
            returned by ``retrieve_ghost_zones(ngz, [], smoothed=True)``
            instead of the mesh itself.
        sort, preload_fields
            Accepted for interface compatibility; not used here.
        """
        sobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        # TODO: each mesh currently becomes exactly one chunk. Finer spatial
        # decomposition and load-balancing (e.g. along a space-filling curve)
        # would be implemented here.
        for og in sobjs:
            # Count on the original mesh first, so that ghost zones are not
            # built for meshes that would be skipped anyway.
            size = self._count_selection(dobj, [og])
            if size == 0:
                continue
            g = og.retrieve_ghost_zones(ngz, [], smoothed=True) if ngz > 0 else og
            yield YTDataChunk(dobj, "spatial", [g], size)

    def _chunk_io(self, dobj, cache=True, local_only=False):
        """Yield one "io" chunk per mesh in the current chunk.

        Each chunk's size is the number of elements selected in that mesh
        alone. ``local_only`` is accepted for interface compatibility and is
        not used here.
        """
        oobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        for subset in oobjs:
            # Count only this chunk's mesh. Counting over all of ``oobjs``
            # would make every chunk report the grand total whenever there is
            # more than one mesh.
            size = self._count_selection(dobj, [subset])
            yield YTDataChunk(dobj, "io", [subset], size, cache=cache)
83