import os
import time
import uuid
import weakref
from itertools import chain, product, repeat
from numbers import Number as numeric_type

import numpy as np
from more_itertools import always_iterable

from yt.data_objects.field_data import YTFieldData
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.index_subobjects.octree_subset import OctreeSubset
from yt.data_objects.index_subobjects.unstructured_mesh import (
    SemiStructuredMesh,
    UnstructuredMesh,
)
from yt.data_objects.particle_unions import ParticleUnion
from yt.data_objects.static_output import Dataset, ParticleFile
from yt.data_objects.unions import MeshUnion
from yt.frontends.sph.data_structures import SPHParticleIndex
from yt.geometry.geometry_handler import YTDataChunk
from yt.geometry.grid_geometry_handler import GridIndex
from yt.geometry.oct_container import OctreeContainer
from yt.geometry.oct_geometry_handler import OctreeIndex
from yt.geometry.unstructured_mesh_handler import UnstructuredIndex
from yt.units import YTQuantity
from yt.utilities.io_handler import io_registry
from yt.utilities.lib.cykdtree import PyKDTree
from yt.utilities.lib.misc_utilities import get_box_grids_level
from yt.utilities.lib.particle_kdtree_tools import (
    estimate_density,
    generate_smoothing_length,
)
from yt.utilities.logger import ytLogger as mylog

from .definitions import process_data, set_particle_types
from .fields import StreamFieldInfo


class StreamGrid(AMRGridPatch):
    """
    Class representing a single in-memory grid instance.
    """

    __slots__ = ["proc_num"]
    _id_offset = 0

    def __init__(self, id, index):
        """
        Returns an instance of StreamGrid with *id*, associated with *index*.
        """
        # All of the field parameters will be passed to us as needed.
        AMRGridPatch.__init__(self, id, filename=None, index=index)
        self._children_ids = []
        self._parent_id = -1
        self.Level = -1

    def set_filename(self, filename):
        pass

    def __repr__(self):
        return "StreamGrid_%04i" % (self.id)

    @property
    def Parent(self):
        if self._parent_id == -1:
            return None
        return self.index.grids[self._parent_id - self._id_offset]

    @property
    def Children(self):
        return [self.index.grids[cid - self._id_offset] for cid in self._children_ids]


class StreamHandler:
    def __init__(
        self,
        left_edges,
        right_edges,
        dimensions,
        levels,
        parent_ids,
        particle_count,
        processor_ids,
        fields,
        field_units,
        code_units,
        io=None,
        particle_types=None,
        periodicity=(True, True, True),
    ):
        if particle_types is None:
            particle_types = {}
        self.left_edges = np.array(left_edges)
        self.right_edges = np.array(right_edges)
        self.dimensions = dimensions
        self.levels = levels
        self.parent_ids = parent_ids
        self.particle_count = particle_count
        self.processor_ids = processor_ids
        self.num_grids = self.levels.size
        self.fields = fields
        self.field_units = field_units
        self.code_units = code_units
        self.io = io
        self.particle_types = particle_types
        self.periodicity = periodicity

    def get_fields(self):
        return self.fields.all_fields

    def get_particle_type(self, field):
        if field in self.particle_types:
            return self.particle_types[field]
        else:
            return False


class StreamHierarchy(GridIndex):

    grid = StreamGrid

    def __init__(self, ds, dataset_type=None):
        self.dataset_type = dataset_type
        self.float_type = "float64"
        self.dataset = weakref.proxy(ds)  # for _obtain_enzo
        self.stream_handler = ds.stream_handler
        self.directory = os.getcwd()
        GridIndex.__init__(self, ds, dataset_type)

    def _count_grids(self):
        self.num_grids = self.stream_handler.num_grids

    def _parse_index(self):
        self.grid_dimensions = self.stream_handler.dimensions
        self.grid_left_edge[:] = self.stream_handler.left_edges
        self.grid_right_edge[:] = self.stream_handler.right_edges
        self.grid_levels[:] = self.stream_handler.levels
        self.min_level = self.grid_levels.min()
        self.grid_procs = self.stream_handler.processor_ids
        self.grid_particle_count[:] = self.stream_handler.particle_count
        mylog.debug("Copying reverse tree")
        self.grids = []
        # Grid ids are 0-indexed; a parent id of -1 marks a root grid.
        for id in range(self.num_grids):
            self.grids.append(self.grid(id, self))
            self.grids[id].Level = self.grid_levels[id, 0]
        parent_ids = self.stream_handler.parent_ids
        if parent_ids is not None:
            reverse_tree = self.stream_handler.parent_ids.tolist()
            # Initial setup:
            for gid, pid in enumerate(reverse_tree):
                if pid >= 0:
                    self.grids[gid]._parent_id = pid
                    self.grids[pid]._children_ids.append(self.grids[gid].id)
        else:
            mylog.debug("Reconstructing parent-child relationships")
            self._reconstruct_parent_child()
        self.max_level = self.grid_levels.max()
        mylog.debug("Preparing grids")
        temp_grids = np.empty(self.num_grids, dtype="object")
        for i, grid in enumerate(self.grids):
            if (i % 1e4) == 0:
                mylog.debug("Prepared % 7i / % 7i grids", i, self.num_grids)
            grid.filename = None
            grid._prepare_grid()
            grid._setup_dx()
            grid.proc_num = self.grid_procs[i]
            temp_grids[i] = grid
        self.grids = temp_grids
        mylog.debug("Prepared")

    def _reconstruct_parent_child(self):
        mask = np.empty(len(self.grids), dtype="int32")
        mylog.debug("First pass; identifying child grids")
        for i, grid in enumerate(self.grids):
            get_box_grids_level(
                self.grid_left_edge[i, :],
                self.grid_right_edge[i, :],
                self.grid_levels[i] + 1,
                self.grid_left_edge,
                self.grid_right_edge,
                self.grid_levels,
                mask,
            )
            ids = np.where(mask.astype("bool"))
            grid._children_ids = ids[0]  # np.where returns a tuple
        mylog.debug("Second pass; identifying parents")
        self.stream_handler.parent_ids = (
            np.zeros(self.stream_handler.num_grids, "int64") - 1
        )
        for i, grid in enumerate(self.grids):  # Second pass
            for child in grid.Children:
                child._parent_id = i
                # _id_offset = 0
                self.stream_handler.parent_ids[child.id] = i

    def _initialize_grid_arrays(self):
        GridIndex._initialize_grid_arrays(self)
        self.grid_procs = np.zeros((self.num_grids, 1), "int32")

    def _detect_output_fields(self):
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        fl = set(self.stream_handler.get_fields())
        fl.update(set(getattr(self, "field_list", [])))
        self.field_list = list(fl)

    def _populate_grid_objects(self):
        for g in self.grids:
            g._setup_dx()
        self.max_level = self.grid_levels.max()

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _reset_particle_count(self):
        self.grid_particle_count[:] = self.stream_handler.particle_count
        for i, grid in enumerate(self.grids):
            grid.NumberOfParticles = self.grid_particle_count[i, 0]

    def update_data(self, data):
        """
        Update the stream data with a new data dict. If fields already exist,
        they will be replaced, but if they do not, they will be added. Fields
        already in the stream but not part of the data dict will be left
        alone.
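
        Examples
        --------
        A minimal sketch only, assuming ``ds`` is an in-memory dataset built
        with ``yt.load_uniform_grid`` or ``yt.load_amr_grids`` and that its
        first grid has shape (16, 16, 16); the field name and values are
        purely illustrative:

        >>> import numpy as np
        >>> new_data = {0: {("gas", "temperature"): np.full((16, 16, 16), 1.0e4)}}
        >>> ds.index.update_data(new_data)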
        """
        particle_types = set_particle_types(data[0])

        self.stream_handler.particle_types.update(particle_types)
        self.ds._find_particle_types()

        for i, grid in enumerate(self.grids):
            field_units, gdata, number_of_particles = process_data(data[i])
            self.stream_handler.particle_count[i] = number_of_particles
            self.stream_handler.field_units.update(field_units)
            for field in gdata:
                if field in grid.field_data:
                    grid.field_data.pop(field, None)
                self.stream_handler.fields[grid.id][field] = gdata[field]

        self._reset_particle_count()
        # We only want to create a superset of fields here.
        for field in list(self.ds.field_list):
            if field[0] == "all":
                self.ds.field_list.remove(field)
        self._detect_output_fields()
        self.ds.create_field_info()
        mylog.debug("Creating Particle Union 'all'")
        pu = ParticleUnion("all", list(self.ds.particle_types_raw))
        self.ds.add_particle_union(pu)
        self.ds.particle_types = tuple(set(self.ds.particle_types))


class StreamDataset(Dataset):
    _index_class = StreamHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
    ):
        self.fluid_types += ("stream",)
        self.geometry = geometry
        self.stream_handler = stream_handler
        self._find_particle_types()
        name = f"InMemoryParameterFile_{uuid.uuid4().hex}"
        from yt.data_objects.static_output import _cached_datasets

        _cached_datasets[name] = self
        Dataset.__init__(
            self,
            name,
            self._dataset_type,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )

    def _parse_parameter_file(self):
        self.basename = self.stream_handler.name
        self.parameters["CurrentTimeIdentifier"] = time.time()
        self.unique_identifier = self.parameters["CurrentTimeIdentifier"]
        self.domain_left_edge = self.stream_handler.domain_left_edge.copy()
        self.domain_right_edge = self.stream_handler.domain_right_edge.copy()
        self.refine_by = self.stream_handler.refine_by
        self.dimensionality = self.stream_handler.dimensionality
        self._periodicity = self.stream_handler.periodicity
        self.domain_dimensions = self.stream_handler.domain_dimensions
        self.current_time = self.stream_handler.simulation_time
        self.gamma = 5.0 / 3.0
        self.parameters["EOSType"] = -1
        self.parameters["CosmologyHubbleConstantNow"] = 1.0
        self.parameters["CosmologyCurrentRedshift"] = 1.0
        self.parameters["HydroMethod"] = -1
        if self.stream_handler.cosmology_simulation:
            self.cosmological_simulation = 1
            self.current_redshift = self.stream_handler.current_redshift
            self.omega_lambda = self.stream_handler.omega_lambda
            self.omega_matter = self.stream_handler.omega_matter
            self.hubble_constant = self.stream_handler.hubble_constant
        else:
            self.current_redshift = 0.0
            self.omega_lambda = 0.0
            self.omega_matter = 0.0
            self.hubble_constant = 0.0
            self.cosmological_simulation = 0

    def _set_units(self):
        self.field_units = self.stream_handler.field_units

    def _set_code_unit_attributes(self):
        base_units = self.stream_handler.code_units
        attrs = (
            "length_unit",
            "mass_unit",
            "time_unit",
            "velocity_unit",
            "magnetic_unit",
        )
        cgs_units = ("cm", "g", "s", "cm/s", "gauss")
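        # Each entry of base_units may be given in several equivalent forms,
        # which the loop below normalizes into YTQuantity attributes.
        # Illustrative values only:
        #   "Mpc"                    a unit string
        #   3.0856e24                a bare number, interpreted in the
        #                            corresponding cgs unit listed above
        #   (3.0856e24, "cm")        a (value, unit) tuple
        #   self.quan(1.0, "Mpc")    an existing YTQuantity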
        for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
            if isinstance(unit, str):
                uq = self.quan(1.0, unit)
            elif isinstance(unit, numeric_type):
                uq = self.quan(unit, cgs_unit)
            elif isinstance(unit, YTQuantity):
                uq = unit
            elif isinstance(unit, tuple):
                uq = self.quan(unit[0], unit[1])
            else:
                raise RuntimeError(f"{attr} ({unit}) is invalid.")
            setattr(self, attr, uq)

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        return False

    @property
    def _skip_cache(self):
        return True

    def _find_particle_types(self):
        particle_types = set()
        for k, v in self.stream_handler.particle_types.items():
            if v:
                particle_types.add(k[0])
        self.particle_types = tuple(particle_types)
        self.particle_types_raw = self.particle_types


class StreamDictFieldHandler(dict):
    _additional_fields = ()

    @property
    def all_fields(self):
        self_fields = chain.from_iterable(s.keys() for s in self.values())
        self_fields = list(set(self_fields))
        fields = list(self._additional_fields) + self_fields
        fields = list(set(fields))
        return fields


class StreamParticleIndex(SPHParticleIndex):
    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def update_data(self, data):
        """
        Update the stream data with a new data dict. If fields already exist,
        they will be replaced, but if they do not, they will be added. Fields
        already in the stream but not part of the data dict will be left
        alone.
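
        Examples
        --------
        A minimal sketch only, assuming ``ds`` was created with
        ``yt.load_particles``; the field name is purely illustrative:

        >>> import numpy as np
        >>> npart = ds.particle_type_counts["io"]
        >>> ds.index.update_data({("io", "particle_temperature"): np.ones(npart)})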
        """
        # Alias
        ds = self.ds
        handler = ds.stream_handler

        # Preprocess
        field_units, data, _ = process_data(data)
        pdata = {}
        for key in data.keys():
            if not isinstance(key, tuple):
                field = ("io", key)
                mylog.debug("Reassigning '%s' to '%s'", key, field)
            else:
                field = key
            pdata[field] = data[key]
        data = pdata  # Drop reference count
        particle_types = set_particle_types(data)

        # Update particle types
        handler.particle_types.update(particle_types)
        ds._find_particle_types()

        # Update fields
        handler.field_units.update(field_units)
        fields = handler.fields
        for field in data.keys():
            if field not in fields._additional_fields:
                fields._additional_fields += (field,)
        fields["stream_file"].update(data)

        # Update field list
        for field in list(self.ds.field_list):
            if field[0] in ["all", "nbody"]:
                self.ds.field_list.remove(field)
        self._detect_output_fields()
        self.ds.create_field_info()


class StreamParticleFile(ParticleFile):
    pass


class StreamParticlesDataset(StreamDataset):
    _index_class = StreamParticleIndex
    _file_class = StreamParticleFile
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_particles"
    file_count = 1
    filename_template = "stream_file"
    _proj_type = "particle_proj"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
    ):
        super().__init__(
            stream_handler,
            storage_filename=storage_filename,
            geometry=geometry,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )
        fields = list(stream_handler.fields["stream_file"].keys())
        # This is the current method of detecting SPH data.
        # This should be made more flexible in the future.
        if ("io", "density") in fields and ("io", "smoothing_length") in fields:
            self._sph_ptypes = ("io",)

    def add_sph_fields(self, n_neighbors=32, kernel="cubic", sph_ptype="io"):
        """Add SPH fields for the specified particle type.

        For a particle type with "particle_position" and "particle_mass" already
        defined, this method adds the "smoothing_length" and "density" fields.
        "smoothing_length" is computed as the distance to the nth nearest
        neighbor. "density" is computed as the SPH (gather) smoothed mass. The
        SPH fields are added only if they don't already exist.

        Parameters
        ----------
        n_neighbors : int
            The number of neighbors to use in the smoothing length computation.
        kernel : str
            The kernel function to use in the density estimation.
        sph_ptype : str
            The SPH particle type. Each dataset has only one sph_ptype; this
            method will overwrite the dataset's existing sph_ptype.

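        Examples
        --------
        A minimal sketch only, assuming a small random particle set loaded
        through ``yt.load_particles`` (the values are purely illustrative):

        >>> import numpy as np
        >>> import yt
        >>> n = 100
        >>> data = {
        ...     "particle_position_x": np.random.rand(n),
        ...     "particle_position_y": np.random.rand(n),
        ...     "particle_position_z": np.random.rand(n),
        ...     "particle_mass": np.ones(n),
        ... }
        >>> ds = yt.load_particles(data, length_unit=1.0, mass_unit=1.0)
        >>> ds.add_sph_fields(n_neighbors=32, kernel="cubic", sph_ptype="io")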
        """
        mylog.info("Generating SPH fields")

        # Unify units
        l_unit = "code_length"
        m_unit = "code_mass"
        d_unit = "code_mass / code_length**3"

        # Read basic fields
        ad = self.all_data()
        pos = ad[sph_ptype, "particle_position"].to(l_unit).d
        mass = ad[sph_ptype, "particle_mass"].to(m_unit).d

        # Construct k-d tree
        kdtree = PyKDTree(
            pos.astype("float64"),
            left_edge=self.domain_left_edge.to_value(l_unit),
            right_edge=self.domain_right_edge.to_value(l_unit),
            periodic=self.periodicity,
            leafsize=2 * int(n_neighbors),
        )
        order = np.argsort(kdtree.idx)

        def exists(fname):
            if (sph_ptype, fname) in self.derived_field_list:
                mylog.info(
                    "Field ('%s','%s') already exists. Skipping", sph_ptype, fname
                )
                return True
            else:
                mylog.info("Generating field ('%s','%s')", sph_ptype, fname)
                return False

        data = {}

        # Add smoothing length field
        fname = "smoothing_length"
        if not exists(fname):
            hsml = generate_smoothing_length(pos[kdtree.idx], kdtree, n_neighbors)
            hsml = hsml[order]
            data[(sph_ptype, "smoothing_length")] = (hsml, l_unit)
        else:
            hsml = ad[sph_ptype, fname].to(l_unit).d

        # Add density field
        fname = "density"
        if not exists(fname):
            dens = estimate_density(
                pos[kdtree.idx],
                mass[kdtree.idx],
                hsml[kdtree.idx],
                kdtree,
                kernel_name=kernel,
            )
            dens = dens[order]
            data[(sph_ptype, "density")] = (dens, d_unit)

        # Add fields
        self._sph_ptypes = (sph_ptype,)
        self.index.update_data(data)
        self.num_neighbors = n_neighbors


_cis = np.fromiter(
    chain.from_iterable(product([0, 1], [0, 1], [0, 1])), dtype=np.int64, count=8 * 3
)
_cis.shape = (8, 3)
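# _cis holds the (i, j, k) offsets of the eight corners of a single hexahedral
# cell; hexahedral_connectivity below adds these offsets to every cell origin
# when building the connectivity array.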


def hexahedral_connectivity(xgrid, ygrid, zgrid):
    r"""Define the vertex coordinates and vertex connectivity of a hexahedral
    mesh on a semistructured grid. Used to specify the connectivity and
    coordinates parameters used in
    :func:`~yt.frontends.stream.data_structures.load_hexahedral_mesh`.

    Parameters
    ----------
    xgrid : array_like
       x-coordinates of the boundaries of the hexahedral cells. Should be a
       one-dimensional array.
    ygrid : array_like
       y-coordinates of the boundaries of the hexahedral cells. Should be a
       one-dimensional array.
    zgrid : array_like
       z-coordinates of the boundaries of the hexahedral cells. Should be a
       one-dimensional array.

    Returns
    -------
    coords : array_like
        The list of (x, y, z) coordinates of the vertices of the mesh.
        Is of size (M, 3), where M is the number of vertices.
    connectivity : array_like
        For each hexahedron h in the mesh, gives the indices into *coords* of
        h's eight vertices. Is of size (N, 8), where N is the number of
        hexahedra.

    Examples
    --------

    >>> xgrid = np.array([-1, -0.25, 0, 0.25, 1])
    >>> coords, conn = hexahedral_connectivity(xgrid, xgrid, xgrid)
    >>> coords
    array([[-1.  , -1.  , -1.  ],
           [-1.  , -1.  , -0.25],
           [-1.  , -1.  ,  0.  ],
           ...,
           [ 1.  ,  1.  ,  0.  ],
           [ 1.  ,  1.  ,  0.25],
           [ 1.  ,  1.  ,  1.  ]])

    >>> conn
    array([[  0,   1,   5,   6,  25,  26,  30,  31],
           [  1,   2,   6,   7,  26,  27,  31,  32],
           [  2,   3,   7,   8,  27,  28,  32,  33],
           ...,
           [ 91,  92,  96,  97, 116, 117, 121, 122],
           [ 92,  93,  97,  98, 117, 118, 122, 123],
           [ 93,  94,  98,  99, 118, 119, 123, 124]])
    """
    nx = len(xgrid)
    ny = len(ygrid)
    nz = len(zgrid)
    coords = np.zeros((nx, ny, nz, 3), dtype="float64", order="C")
    coords[:, :, :, 0] = xgrid[:, None, None]
    coords[:, :, :, 1] = ygrid[None, :, None]
    coords[:, :, :, 2] = zgrid[None, None, :]
    coords.shape = (nx * ny * nz, 3)
    cycle = np.rollaxis(np.indices((nx - 1, ny - 1, nz - 1)), 0, 4)
    cycle.shape = ((nx - 1) * (ny - 1) * (nz - 1), 3)
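    # `off` gives, for every cell, the (i, j, k) indices of its eight corner
    # vertices; these triplets are then collapsed to flat row offsets into
    # `coords`, matching the C-ordered (nx, ny, nz) -> (nx * ny * nz, 3)
    # reshape performed above.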
    off = _cis + cycle[:, np.newaxis]
    connectivity = np.array(
        ((off[:, :, 0] * ny) + off[:, :, 1]) * nz + off[:, :, 2], order="C"
    )
    return coords, connectivity


class StreamHexahedralMesh(SemiStructuredMesh):
    _connectivity_length = 8
    _index_offset = 0


class StreamHexahedralHierarchy(UnstructuredIndex):
    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        coords = self.stream_handler.fields.pop("coordinates")
        connect = self.stream_handler.fields.pop("connectivity")
        self.meshes = [
            StreamHexahedralMesh(0, self.index_filename, connect, coords, self)
        ]

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _detect_output_fields(self):
        self.field_list = list(set(self.stream_handler.get_fields()))


class StreamHexahedralDataset(StreamDataset):
    _index_class = StreamHexahedralHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_hexahedral"


class StreamOctreeSubset(OctreeSubset):
    domain_id = 1
    _domain_offset = 1

    def __init__(
        self, base_region, ds, oct_handler, over_refine_factor=1, num_ghost_zones=0
    ):
        self._over_refine_factor = over_refine_factor
        self._num_zones = 1 << (over_refine_factor)
        self.field_data = YTFieldData()
        self.field_parameters = {}
        self.ds = ds
        self.oct_handler = oct_handler
        self._last_mask = None
        self._last_selector_id = None
        self._current_particle_type = "io"
        self._current_fluid_type = self.ds.default_fluid_type
        self.base_region = base_region
        self.base_selector = base_region.selector

        self._num_ghost_zones = num_ghost_zones

        if num_ghost_zones > 0:
            if not all(ds.periodicity):
                mylog.warning(
                    "Ghost zones will wrongly assume the domain to be periodic."
                )
            base_grid = StreamOctreeSubset(
                base_region, ds, oct_handler, over_refine_factor
            )
            self._base_grid = base_grid

    def retrieve_ghost_zones(self, ngz, fields, smoothed=False):
        try:
            new_subset = self._subset_with_gz
            mylog.debug("Reusing previous subset with ghost zones.")
        except AttributeError:
            new_subset = StreamOctreeSubset(
                self.base_region,
                self.ds,
                self.oct_handler,
                self._over_refine_factor,
                num_ghost_zones=ngz,
            )
            self._subset_with_gz = new_subset

        return new_subset

    def _fill_no_ghostzones(self, content, dest, selector, offset):
        # Here we get a copy of the file, which we skip through and read the
        # bits we want.
        oct_handler = self.oct_handler
        cell_count = selector.count_oct_cells(self.oct_handler, self.domain_id)
        levels, cell_inds, file_inds = self.oct_handler.file_index_octs(
            selector, self.domain_id, cell_count
        )
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        count = oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )
        return count

    def _fill_with_ghostzones(self, content, dest, selector, offset):
        oct_handler = self.oct_handler
        ndim = self.ds.dimensionality
        cell_count = (
            selector.count_octs(self.oct_handler, self.domain_id) * self.nz ** ndim
        )

        gz_cache = getattr(self, "_ghost_zone_cache", None)
        if gz_cache:
            levels, cell_inds, file_inds, domains = gz_cache
        else:
            gz_cache = (
                levels,
                cell_inds,
                file_inds,
                domains,
            ) = oct_handler.file_index_octs_with_ghost_zones(
                selector, self.domain_id, cell_count
            )
            self._ghost_zone_cache = gz_cache
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        oct_handler.fill_level(0, levels, cell_inds, file_inds, dest, content, offset)

    def fill(self, content, dest, selector, offset):
        if self._num_ghost_zones == 0:
            return self._fill_no_ghostzones(content, dest, selector, offset)
        else:
            return self._fill_with_ghostzones(content, dest, selector, offset)


class StreamOctreeHandler(OctreeIndex):
    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        self.dataset_type = dataset_type
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _initialize_oct_handler(self):
        header = dict(
            dims=[1, 1, 1],
            left_edge=self.ds.domain_left_edge,
            right_edge=self.ds.domain_right_edge,
            octree=self.ds.octree_mask,
            over_refine=self.ds.over_refine_factor,
            partial_coverage=self.ds.partial_coverage,
        )
        self.oct_handler = OctreeContainer.load_octree(header)

    def _identify_base_chunk(self, dobj):
        if getattr(dobj, "_chunk_info", None) is None:
            base_region = getattr(dobj, "base_region", dobj)
            subset = [
                StreamOctreeSubset(
                    base_region,
                    self.dataset,
                    self.oct_handler,
                    self.ds.over_refine_factor,
                )
            ]
            dobj._chunk_info = subset
        dobj._current_chunk = list(self._chunk_all(dobj))[0]

    def _chunk_all(self, dobj):
        oobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        yield YTDataChunk(dobj, "all", oobjs, None)

    def _chunk_spatial(self, dobj, ngz, sort=None, preload_fields=None):
        sobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        # This is where we will perform cutting of the Octree and
        # load-balancing.  That may require a specialized selector object to
        # cut based on some space-filling curve index.
        for og in sobjs:
            if ngz > 0:
                g = og.retrieve_ghost_zones(ngz, [], smoothed=True)
            else:
                g = og
            yield YTDataChunk(dobj, "spatial", [g])

    def _chunk_io(self, dobj, cache=True, local_only=False):
        oobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        for subset in oobjs:
            yield YTDataChunk(dobj, "io", [subset], None, cache=cache)

    def _setup_classes(self):
        dd = self._get_data_reader_dict()
        super()._setup_classes(dd)

    def _detect_output_fields(self):
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        fl = set(self.stream_handler.get_fields())
        fl.update(set(getattr(self, "field_list", [])))
        self.field_list = list(fl)


class StreamOctreeDataset(StreamDataset):
    _index_class = StreamOctreeHandler
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_octree"

    levelmax = None

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
    ):
        super().__init__(
            stream_handler,
            storage_filename,
            geometry,
            unit_system,
            default_species_fields=default_species_fields,
        )
        # Set up levelmax
        self.max_level = stream_handler.levels.max()
        self.min_level = stream_handler.levels.min()


class StreamUnstructuredMesh(UnstructuredMesh):
    _index_offset = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._connectivity_length = self.connectivity_indices.shape[1]


class StreamUnstructuredIndex(UnstructuredIndex):
    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        coords = self.stream_handler.fields.pop("coordinates")
        connect = always_iterable(self.stream_handler.fields.pop("connectivity"))

        self.meshes = [
            StreamUnstructuredMesh(i, self.index_filename, c1, c2, self)
            for i, (c1, c2) in enumerate(zip(connect, repeat(coords)))
        ]
        self.mesh_union = MeshUnion("mesh_union", self.meshes)

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _detect_output_fields(self):
        self.field_list = list(set(self.stream_handler.get_fields()))
        fnames = list({fn for ft, fn in self.field_list})
        self.field_list += [("all", fname) for fname in fnames]


class StreamUnstructuredMeshDataset(StreamDataset):
    _index_class = StreamUnstructuredIndex
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_unstructured"

    def _find_particle_types(self):
        pass