import os
import weakref
from collections import defaultdict
from numbers import Number as numeric_type

import numpy as np

from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.particle_unions import ParticleUnion
from yt.data_objects.profiles import (
    Profile1DFromDataset,
    Profile2DFromDataset,
    Profile3DFromDataset,
)
from yt.data_objects.static_output import Dataset, ParticleFile, validate_index_order
from yt.fields.field_exceptions import NeedsGridType
from yt.funcs import is_root, parse_h5_attr
from yt.geometry.grid_geometry_handler import GridIndex
from yt.geometry.particle_geometry_handler import ParticleIndex
from yt.units import dimensions
from yt.units.unit_registry import UnitRegistry
from yt.units.yt_array import YTQuantity, uconcatenate
from yt.utilities.exceptions import GenerationInProgress, YTFieldTypeNotFound
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_root_only
from yt.utilities.tree_container import TreeContainer

from .fields import YTDataContainerFieldInfo, YTGridFieldInfo

_grid_data_containers = ["arbitrary_grid", "covering_grid", "smoothed_covering_grid"]
_set_attrs = {"periodicity": "_periodicity"}

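# Hedged usage sketch (not part of the module): the datasets handled here are
# typically produced by calling save_as_dataset() on a data container and then
# reloaded with yt.load(). File and field names below are illustrative.
#
#   import yt
#   ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
#   sp = ds.sphere("max", (1.0, "Mpc"))
#   fn = sp.save_as_dataset(fields=[("gas", "density")])
#   sphere_ds = yt.load(fn)  # returns a YTDataContainerDataset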

class SavedDataset(Dataset):
    """
    Base dataset class for products of calling save_as_dataset.
    """

    _con_attrs = ()

    def _parse_parameter_file(self):
        self.refine_by = 2
        with h5py.File(self.parameter_filename, mode="r") as f:
            for key in f.attrs.keys():
                v = parse_h5_attr(f, key)
                if key == "con_args":
                    try:
                        v = eval(v)
                    except ValueError:
                        # support older ytdata outputs
                        v = v.astype("str")
                    except NameError:
                        # This is the most common error we expect, and it
                        # results from having the eval do a concatenated decoded
                        # set of the values.
                        v = [_.decode("utf8") for _ in v]
                self.parameters[key] = v
            self._with_parameter_file_open(f)

        # if saved, restore unit registry from the json string
        if "unit_registry_json" in self.parameters:
            self.unit_registry = UnitRegistry.from_json(
                self.parameters["unit_registry_json"]
            )
            # reset self.arr and self.quan to use new unit_registry
            self._arr = None
            self._quan = None
            for dim in [
                "length",
                "mass",
                "pressure",
                "temperature",
                "time",
                "velocity",
            ]:
                cu = "code_" + dim
                if cu not in self.unit_registry:
                    self.unit_registry.add(cu, 1.0, getattr(dimensions, dim))
            if "code_magnetic" not in self.unit_registry:
                self.unit_registry.add("code_magnetic", 1.0, dimensions.magnetic_field)

        # if saved, set unit system
        if "unit_system_name" in self.parameters:
            unit_system = self.parameters["unit_system_name"]
            del self.parameters["unit_system_name"]
        else:
            unit_system = "cgs"
        # reset unit system since we may have a new unit registry
        self._assign_unit_system(unit_system)

        # assign units to parameters that have an associated unit string
        del_pars = []
        for par in self.parameters:
            ustr = f"{par}_units"
            if ustr in self.parameters:
                if isinstance(self.parameters[par], np.ndarray):
                    to_u = self.arr
                else:
                    to_u = self.quan
                self.parameters[par] = to_u(self.parameters[par], self.parameters[ustr])
                del_pars.append(ustr)
        for par in del_pars:
            del self.parameters[par]

        for attr in self._con_attrs:
            try:
                sattr = _set_attrs.get(attr, attr)
                setattr(self, sattr, self.parameters.get(attr))
            except TypeError:
                # some Dataset attributes are properties with setters
                # which may not accept None as an input
                pass

        if self.geometry is None:
            self.geometry = "cartesian"

    def _with_parameter_file_open(self, f):
        # This allows subclasses to access the parameter file
        # while it's still open to get additional information.
        pass
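    # Hedged sketch: a subclass can override the hook above to pull extra
    # information from the still-open HDF5 handle. The attribute name
    # "num_elements" mirrors YTDataset below; everything else is illustrative.
    #
    #   class MySavedDataset(SavedDataset):
    #       def _with_parameter_file_open(self, f):
    #           self.my_counts = {
    #               group: parse_h5_attr(f[group], "num_elements") for group in f
    #           }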

    def set_units(self):
        if "unit_registry_json" in self.parameters:
            self._set_code_unit_attributes()
            del self.parameters["unit_registry_json"]
        else:
            super().set_units()

    def _set_code_unit_attributes(self):
        attrs = (
            "length_unit",
            "mass_unit",
            "time_unit",
            "velocity_unit",
            "magnetic_unit",
        )
        cgs_units = ("cm", "g", "s", "cm/s", "gauss")
        base_units = np.ones(len(attrs))
        for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
            if attr in self.parameters and isinstance(
                self.parameters[attr], YTQuantity
            ):
                uq = self.parameters[attr]
            elif attr in self.parameters and f"{attr}_units" in self.parameters:
                uq = self.quan(self.parameters[attr], self.parameters[f"{attr}_units"])
                del self.parameters[attr]
                del self.parameters[f"{attr}_units"]
            elif isinstance(unit, str):
                uq = self.quan(1.0, unit)
            elif isinstance(unit, numeric_type):
                uq = self.quan(unit, cgs_unit)
            elif isinstance(unit, YTQuantity):
                uq = unit
            elif isinstance(unit, tuple):
                uq = self.quan(unit[0], unit[1])
            else:
                raise RuntimeError(f"{attr} ({unit}) is invalid.")
            setattr(self, attr, uq)


class YTDataset(SavedDataset):
    """Base dataset class for all ytdata datasets."""

    _con_attrs = (
        "cosmological_simulation",
        "current_time",
        "current_redshift",
        "hubble_constant",
        "omega_matter",
        "omega_lambda",
        "dimensionality",
        "domain_dimensions",
        "geometry",
        "periodicity",
        "domain_left_edge",
        "domain_right_edge",
        "container_type",
        "data_type",
    )

    def _with_parameter_file_open(self, f):
        self.num_particles = {
            group: parse_h5_attr(f[group], "num_elements")
            for group in f
            if group != self.default_fluid_type
        }

    def create_field_info(self):
        self.field_dependencies = {}
        self.derived_field_list = []
        self.filtered_particle_types = []
        self.field_info = self._field_info_class(self, self.field_list)
        self.coordinates.setup_fields(self.field_info)
        self.field_info.setup_fluid_fields()
        for ptype in self.particle_types:
            self.field_info.setup_particle_fields(ptype)

        self._setup_gas_alias()
        self.field_info.setup_fluid_index_fields()

        if "all" not in self.particle_types:
            mylog.debug("Creating Particle Union 'all'")
            pu = ParticleUnion("all", list(self.particle_types_raw))
            self.add_particle_union(pu)
        self.field_info.setup_extra_union_fields()
        mylog.debug("Loading field plugins.")
        self.field_info.load_all_plugins()
        deps, unloaded = self.field_info.check_derived_fields()
        self.field_dependencies.update(deps)

    def _setup_gas_alias(self):
        pass

    def _setup_override_fields(self):
        pass


class YTDataHDF5File(ParticleFile):
    def __init__(self, ds, io, filename, file_id, range):
        with h5py.File(filename, mode="r") as f:
            self.header = {field: parse_h5_attr(f, field) for field in f.attrs.keys()}

        super().__init__(ds, io, filename, file_id, range)


class YTDataContainerDataset(YTDataset):
    """Dataset for saved geometric data containers."""

    _index_class = ParticleIndex
    _file_class = YTDataHDF5File
    _field_info_class = YTDataContainerFieldInfo
    _suffix = ".h5"
    fluid_types = ("grid", "gas", "deposit", "index")

    def __init__(
        self,
        filename,
        dataset_type="ytdatacontainer_hdf5",
        index_order=None,
        index_filename=None,
        units_override=None,
        unit_system="cgs",
    ):
        self.index_order = validate_index_order(index_order)
        self.index_filename = index_filename
        super().__init__(
            filename,
            dataset_type,
            units_override=units_override,
            unit_system=unit_system,
        )

    def _parse_parameter_file(self):
        super()._parse_parameter_file()
        self.particle_types_raw = tuple(self.num_particles.keys())
        self.particle_types = self.particle_types_raw
        self.filename_template = self.parameter_filename
        self.file_count = 1
        self.domain_dimensions = np.ones(3, "int32")

    def _setup_gas_alias(self):
        "Alias the grid type to gas by making a particle union."

        if "grid" in self.particle_types and "gas" not in self.particle_types:
            pu = ParticleUnion("gas", ["grid"])
            self.add_particle_union(pu)
        # We have to alias this because particle unions only
        # cover the field_list.
        self.field_info.alias(("gas", "cell_volume"), ("grid", "cell_volume"))

    _data_obj = None

    @property
    def data(self):
        """
        Return a data container configured like the original used to
        create this dataset.
        """

        if self._data_obj is None:
            # Some data containers can't be reconstructed in the same way
            # since this is now particle-like data.
            data_type = self.parameters.get("data_type")
            container_type = self.parameters.get("container_type")
            ex_container_type = ["cutting", "quad_proj", "ray", "slice", "cut_region"]
            if data_type == "yt_light_ray" or container_type in ex_container_type:
                mylog.info("Returning an all_data data container.")
                return self.all_data()

            my_obj = getattr(self, self.parameters["container_type"])
            my_args = [
                self.parameters[con_arg] for con_arg in self.parameters["con_args"]
            ]
            self._data_obj = my_obj(*my_args)
        return self._data_obj
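    # Hedged usage sketch for the "data" property above. Assumes a file like
    # "sphere.h5" was written earlier with save_as_dataset(); names are
    # illustrative, not part of this module.
    #
    #   import yt
    #   sphere_ds = yt.load("sphere.h5")
    #   sp = sphere_ds.data                  # reconstructed sphere-like container
    #   print(sp["grid", "density"])         # saved fields come back as "grid" fields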

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            cont_type = parse_h5_attr(f, "container_type")
            if data_type is None:
                return False
            if (
                data_type == "yt_data_container"
                and cont_type not in _grid_data_containers
            ):
                return True
        return False


class YTDataLightRayDataset(YTDataContainerDataset):
    """Dataset for saved LightRay objects."""

    def _parse_parameter_file(self):
        super()._parse_parameter_file()
        self._restore_light_ray_solution()

    def _restore_light_ray_solution(self):
        """
        Restore all information associated with the light ray solution
        to its original form.
        """
        key = "light_ray_solution"
        self.light_ray_solution = []
        lrs_fields = [
            par for par in self.parameters if key in par and not par.endswith("_units")
        ]
        if len(lrs_fields) == 0:
            return
        self.light_ray_solution = [{} for val in self.parameters[lrs_fields[0]]]
        for sp3 in ["unique_identifier", "filename"]:
            ksp3 = f"{key}_{sp3}"
            if ksp3 not in lrs_fields:
                continue
            self.parameters[ksp3] = self.parameters[ksp3].astype(str)
        for field in lrs_fields:
            field_name = field[len(key) + 1 :]
            for i in range(self.parameters[field].shape[0]):
                self.light_ray_solution[i][field_name] = self.parameters[field][i]

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            if data_type in ["yt_light_ray"]:
                return True
        return False
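    # Hedged sketch: reloading a saved LightRay output. The file name and the
    # solution key shown ("redshift") are assumptions for illustration; the
    # actual keys depend on what was stored in light_ray_solution.
    #
    #   import yt
    #   lr_ds = yt.load("lightray.h5")
    #   for segment in lr_ds.light_ray_solution:
    #       print(segment.get("redshift"))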


class YTSpatialPlotDataset(YTDataContainerDataset):
    """Dataset for saved slices and projections."""

    _field_info_class = YTGridFieldInfo

    def __init__(self, *args, **kwargs):
        super().__init__(*args, dataset_type="ytspatialplot_hdf5", **kwargs)

    def _parse_parameter_file(self):
        super()._parse_parameter_file()
        if self.parameters["container_type"] == "proj":
            if (
                isinstance(self.parameters["weight_field"], str)
                and self.parameters["weight_field"] == "None"
            ):
                self.parameters["weight_field"] = None
            elif isinstance(self.parameters["weight_field"], np.ndarray):
                self.parameters["weight_field"] = tuple(self.parameters["weight_field"])

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            cont_type = parse_h5_attr(f, "container_type")
            if data_type == "yt_data_container" and cont_type in [
                "cutting",
                "proj",
                "slice",
                "quad_proj",
            ]:
                return True
        return False
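    # Hedged sketch: saving a projection and reloading it as a
    # YTSpatialPlotDataset. The sample dataset name and field are illustrative.
    #
    #   import yt
    #   ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    #   proj = ds.proj(("gas", "density"), "z")
    #   fn = proj.save_as_dataset()
    #   proj_ds = yt.load(fn)
    #   print(proj_ds.data["grid", "density"])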


class YTGrid(AMRGridPatch):
    _id_offset = 0

    def __init__(self, gid, index, filename=None):
        AMRGridPatch.__init__(self, gid, filename=filename, index=index)
        self._children_ids = []
        self._parent_id = -1
        self.Level = 0
        self.LeftEdge = self.index.ds.domain_left_edge
        self.RightEdge = self.index.ds.domain_right_edge

    def __getitem__(self, key):
        tr = super(AMRGridPatch, self).__getitem__(key)
        try:
            fields = self._determine_fields(key)
        except YTFieldTypeNotFound:
            return tr
        finfo = self.ds._get_field_info(*fields[0])
        if not finfo.sampling_type == "particle":
            return tr.reshape(self.ActiveDimensions[: self.ds.dimensionality])
        return tr

    @property
    def Parent(self):
        return None

    @property
    def Children(self):
        return []


class YTDataHierarchy(GridIndex):
    def __init__(self, ds, dataset_type=None):
        self.dataset_type = dataset_type
        self.float_type = "float64"
        self.dataset = weakref.proxy(ds)
        self.directory = os.getcwd()
        super().__init__(ds, dataset_type)

    def _count_grids(self):
        self.num_grids = 1

    def _parse_index(self):
        self.grid_dimensions[:] = self.ds.domain_dimensions
        self.grid_left_edge[:] = self.ds.domain_left_edge
        self.grid_right_edge[:] = self.ds.domain_right_edge
        self.grid_levels[:] = np.zeros(self.num_grids)
        self.grid_procs = np.zeros(self.num_grids)
        self.grid_particle_count[:] = sum(self.ds.num_particles.values())
        self.grids = []
        for gid in range(self.num_grids):
            self.grids.append(self.grid(gid, self))
            self.grids[gid].Level = self.grid_levels[gid, 0]
        self.max_level = self.grid_levels.max()
        temp_grids = np.empty(self.num_grids, dtype="object")
        for i, grid in enumerate(self.grids):
            grid.filename = self.ds.parameter_filename
            grid._prepare_grid()
            grid.proc_num = self.grid_procs[i]
            temp_grids[i] = grid
        self.grids = temp_grids

    def _detect_output_fields(self):
        self.field_list = []
        self.ds.field_units = self.ds.field_units or {}
        with h5py.File(self.ds.parameter_filename, mode="r") as f:
            for group in f:
                for field in f[group]:
                    field_name = (str(group), str(field))
                    self.field_list.append(field_name)
                    self.ds.field_units[field_name] = parse_h5_attr(
                        f[group][field], "units"
                    )


class YTGridHierarchy(YTDataHierarchy):
    grid = YTGrid

    def _populate_grid_objects(self):
        for g in self.grids:
            g._setup_dx()
        self.max_level = self.grid_levels.max()


class YTGridDataset(YTDataset):
    """Dataset for saved covering grids, arbitrary grids, and FRBs."""

    _index_class = YTGridHierarchy
    _field_info_class = YTGridFieldInfo
    _dataset_type = "ytgridhdf5"
    geometry = "cartesian"
    default_fluid_type = "grid"
    fluid_types = ("grid", "gas", "deposit", "index")

    def __init__(self, filename, unit_system="cgs"):
        super().__init__(filename, self._dataset_type, unit_system=unit_system)
        self.data = self.index.grids[0]

    def _parse_parameter_file(self):
        super()._parse_parameter_file()
        self.num_particles.pop(self.default_fluid_type, None)
        self.particle_types_raw = tuple(self.num_particles.keys())
        self.particle_types = self.particle_types_raw

        # correct the domain edges and dimensions for the covering grid
        self.base_domain_left_edge = self.domain_left_edge
        self.base_domain_right_edge = self.domain_right_edge
        self.base_domain_dimensions = self.domain_dimensions

        if self.container_type in _grid_data_containers:
            self.domain_left_edge = self.parameters["left_edge"]

            if "level" in self.parameters["con_args"]:
                dx = (self.base_domain_right_edge - self.base_domain_left_edge) / (
                    self.domain_dimensions * self.refine_by ** self.parameters["level"]
                )
                self.domain_right_edge = (
                    self.domain_left_edge + self.parameters["ActiveDimensions"] * dx
                )
                self.domain_dimensions = (
                    (self.domain_right_edge - self.domain_left_edge) / dx
                ).astype(int)
            else:
                self.domain_right_edge = self.parameters["right_edge"]
                self.domain_dimensions = self.parameters["ActiveDimensions"]
                dx = (
                    self.domain_right_edge - self.domain_left_edge
                ) / self.domain_dimensions

            periodicity = (
                np.abs(self.domain_left_edge - self.base_domain_left_edge) < 0.5 * dx
            )
            periodicity &= (
                np.abs(self.domain_right_edge - self.base_domain_right_edge) < 0.5 * dx
            )
            self._periodicity = periodicity

        elif self.data_type == "yt_frb":
            dle = self.domain_left_edge
            self.domain_left_edge = uconcatenate(
                [self.parameters["left_edge"].to(dle.units), [0] * dle.uq]
            )
            dre = self.domain_right_edge
            self.domain_right_edge = uconcatenate(
                [self.parameters["right_edge"].to(dre.units), [1] * dre.uq]
            )
            self.domain_dimensions = np.concatenate(
                [self.parameters["ActiveDimensions"], [1]]
            )

    def _setup_gas_alias(self):
        "Alias the grid type to gas with a field alias."

        for ftype, field in self.field_list:
            if ftype == "grid":
                self.field_info.alias(("gas", field), ("grid", field))

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            cont_type = parse_h5_attr(f, "container_type")
            if data_type == "yt_frb":
                return True
            if data_type == "yt_data_container" and cont_type in _grid_data_containers:
                return True
        return False
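    # Hedged sketch: writing a covering grid to disk and reloading it as a
    # YTGridDataset. The sample dataset name and field are illustrative.
    #
    #   import yt
    #   ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046")
    #   cg = ds.covering_grid(level=0, left_edge=[0, 0, 0], dims=ds.domain_dimensions)
    #   fn = cg.save_as_dataset(fields=[("gas", "density")])
    #   cg_ds = yt.load(fn)
    #   print(cg_ds.data["grid", "density"].shape)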


class YTNonspatialGrid(AMRGridPatch):
    _id_offset = 0

    def __init__(self, gid, index, filename=None):
        super().__init__(gid, filename=filename, index=index)
        self._children_ids = []
        self._parent_id = -1
        self.Level = 0
        self.LeftEdge = self.index.ds.domain_left_edge
        self.RightEdge = self.index.ds.domain_right_edge

    def __repr__(self):
        return "YTNonspatialGrid"

    def __getitem__(self, key):
        tr = super(AMRGridPatch, self).__getitem__(key)
        try:
            fields = self._determine_fields(key)
        except YTFieldTypeNotFound:
            return tr
        self.ds._get_field_info(*fields[0])
        return tr

    def get_data(self, fields=None):
        if fields is None:
            return
        nfields = []
        apply_fields = defaultdict(list)
        for field in self._determine_fields(fields):
            if field[0] in self.ds.filtered_particle_types:
                f = self.ds.known_filters[field[0]]
                apply_fields[field[0]].append((f.filtered_type, field[1]))
            else:
                nfields.append(field)
        for filter_type in apply_fields:
            f = self.ds.known_filters[filter_type]
            with f.apply(self):
                self.get_data(apply_fields[filter_type])
        fields = nfields
        if len(fields) == 0:
            return
        # Now we collect all our fields
        # Here is where we need to perform a validation step, so that if we
        # have a field requested that we actually *can't* yet get, we put it
        # off until the end.  This prevents double-reading fields that will
        # need to be used in spatial fields later on.
        fields_to_get = []
        # This will be pre-populated with spatial fields
        fields_to_generate = []
        for field in self._determine_fields(fields):
            if field in self.field_data:
                continue
            finfo = self.ds._get_field_info(*field)
            try:
                finfo.check_available(self)
            except NeedsGridType:
                fields_to_generate.append(field)
                continue
            fields_to_get.append(field)
        if len(fields_to_get) == 0 and len(fields_to_generate) == 0:
            return
        elif self._locked:
            raise GenerationInProgress(fields)
        # Track which ones we want in the end
        ofields = set(list(self.field_data.keys()) + fields_to_get + fields_to_generate)
        # At this point, we want to figure out *all* our dependencies.
        fields_to_get = self._identify_dependencies(fields_to_get, self._spatial)
        # We now split up into readers for the types of fields
        fluids, particles = [], []
        finfos = {}
        for ftype, fname in fields_to_get:
            finfo = self.ds._get_field_info(ftype, fname)
            finfos[ftype, fname] = finfo
            if finfo.sampling_type == "particle":
                particles.append((ftype, fname))
            elif (ftype, fname) not in fluids:
                fluids.append((ftype, fname))

        # The _read method will figure out which fields it needs to get from
        # disk, and return a dict of those fields along with the fields that
        # need to be generated.
        read_fluids, gen_fluids = self.index._read_fluid_fields(
            fluids, self, self._current_chunk
        )
        for f, v in read_fluids.items():
            convert = True
            if v.dtype != np.float64:
                if finfos[f].units == "":
                    self.field_data[f] = v
                    convert = False
                else:
                    v = v.astype(np.float64)
            if convert:
                self.field_data[f] = self.ds.arr(v, units=finfos[f].units)
                self.field_data[f].convert_to_units(finfos[f].output_units)

        read_particles, gen_particles = self.index._read_fluid_fields(
            particles, self, self._current_chunk
        )
        for f, v in read_particles.items():
            convert = True
            if v.dtype != np.float64:
                if finfos[f].units == "":
                    self.field_data[f] = v
                    convert = False
                else:
                    v = v.astype(np.float64)
            if convert:
                self.field_data[f] = self.ds.arr(v, units=finfos[f].units)
                self.field_data[f].convert_to_units(finfos[f].output_units)

        fields_to_generate += gen_fluids + gen_particles
        self._generate_fields(fields_to_generate)
        for field in list(self.field_data.keys()):
            if field not in ofields:
                self.field_data.pop(field)

    @property
    def Parent(self):
        return None

    @property
    def Children(self):
        return []


class YTNonspatialHierarchy(YTDataHierarchy):
    grid = YTNonspatialGrid

    def _populate_grid_objects(self):
        for g in self.grids:
            g._setup_dx()
            # this is non-spatial, so remove the code_length units
            g.dds = self.ds.arr(g.dds.d, "")
            g.ActiveDimensions = self.ds.domain_dimensions
        self.max_level = self.grid_levels.max()

    def _read_fluid_fields(self, fields, dobj, chunk=None):
        if len(fields) == 0:
            return {}, []
        fields_to_read, fields_to_generate = self._split_fields(fields)
        if len(fields_to_read) == 0:
            return {}, fields_to_generate
        selector = dobj.selector
        fields_to_return = self.io._read_fluid_selection(dobj, selector, fields_to_read)
        return fields_to_return, fields_to_generate


class YTNonspatialDataset(YTGridDataset):
    """Dataset for general array data."""

    _index_class = YTNonspatialHierarchy
    _field_info_class = YTGridFieldInfo
    _dataset_type = "ytnonspatialhdf5"
    geometry = "cartesian"
    default_fluid_type = "data"
    fluid_types = ("data", "gas")

    def _parse_parameter_file(self):
        super(YTGridDataset, self)._parse_parameter_file()
        self.num_particles.pop(self.default_fluid_type, None)
        self.particle_types_raw = tuple(self.num_particles.keys())
        self.particle_types = self.particle_types_raw

    def _set_derived_attrs(self):
        # set some defaults just to make things go
        default_attrs = {
            "dimensionality": 3,
            "domain_dimensions": np.ones(3, dtype="int64"),
            "domain_left_edge": np.zeros(3),
            "domain_right_edge": np.ones(3),
            "_periodicity": np.ones(3, dtype="bool"),
        }
        for att, val in default_attrs.items():
            if getattr(self, att, None) is None:
                setattr(self, att, val)

    def _setup_classes(self):
        # We don't allow geometric selection for non-spatial datasets
        self.objects = []

    @parallel_root_only
    def print_key_parameters(self):
        for a in [
            "current_time",
            "domain_dimensions",
            "domain_left_edge",
            "domain_right_edge",
            "cosmological_simulation",
        ]:
            v = getattr(self, a)
            if v is not None:
                mylog.info("Parameters: %-25s = %s", a, v)
        if hasattr(self, "cosmological_simulation") and self.cosmological_simulation:
            for a in [
                "current_redshift",
                "omega_lambda",
                "omega_matter",
                "hubble_constant",
            ]:
                v = getattr(self, a)
                if v is not None:
                    mylog.info("Parameters: %-25s = %s", a, v)
        mylog.warning("Geometric data selection not available for this dataset type.")

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            if data_type == "yt_array_data":
                return True
        return False
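    # Hedged sketch: saving arbitrary arrays with yt.save_as_dataset() and
    # reloading them as a YTNonspatialDataset. Array and file names are
    # illustrative.
    #
    #   import numpy as np
    #   import yt
    #   ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    #   my_data = {("data", "density"): ds.arr(np.random.random(10), "g/cm**3")}
    #   fn = yt.save_as_dataset(ds, "random_density.h5", my_data)
    #   array_ds = yt.load(fn)
    #   print(array_ds.data["data", "density"])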


class YTProfileDataset(YTNonspatialDataset):
    """Dataset for saved profile objects."""

    fluid_types = ("data", "gas", "standard_deviation")

    def __init__(self, filename, unit_system="cgs"):
        super().__init__(filename, unit_system=unit_system)

    _profile = None

    @property
    def profile(self):
        if self._profile is not None:
            return self._profile
        if self.dimensionality == 1:
            self._profile = Profile1DFromDataset(self)
        elif self.dimensionality == 2:
            self._profile = Profile2DFromDataset(self)
        elif self.dimensionality == 3:
            self._profile = Profile3DFromDataset(self)
        else:
            self._profile = None
        return self._profile
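    # Hedged sketch: saving a profile and reloading it through the "profile"
    # property above. The sample dataset name and fields are illustrative.
    #
    #   import yt
    #   ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    #   ad = ds.all_data()
    #   prof = yt.create_profile(ad, ("gas", "density"), ("gas", "temperature"))
    #   fn = prof.save_as_dataset()
    #   prof_ds = yt.load(fn)
    #   print(prof_ds.profile["gas", "temperature"])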

    def _parse_parameter_file(self):
        super(YTGridDataset, self)._parse_parameter_file()

        if (
            isinstance(self.parameters["weight_field"], str)
            and self.parameters["weight_field"] == "None"
        ):
            self.parameters["weight_field"] = None
        elif isinstance(self.parameters["weight_field"], np.ndarray):
            self.parameters["weight_field"] = tuple(
                self.parameters["weight_field"].astype(str)
            )

        for a in ["profile_dimensions"] + [
            f"{ax}_{attr}" for ax in "xyz"[: self.dimensionality] for attr in ["log"]
        ]:
            setattr(self, a, self.parameters[a])

        self.base_domain_left_edge = self.domain_left_edge
        self.base_domain_right_edge = self.domain_right_edge
        self.base_domain_dimensions = self.domain_dimensions

        domain_dimensions = np.ones(3, dtype="int64")
        domain_dimensions[: self.dimensionality] = self.profile_dimensions
        self.domain_dimensions = domain_dimensions
        domain_left_edge = np.zeros(3)
        domain_right_edge = np.ones(3)
        for i, ax in enumerate("xyz"[: self.dimensionality]):
            range_name = f"{ax}_range"
            my_range = self.parameters[range_name]
            if getattr(self, f"{ax}_log", False):
                my_range = np.log10(my_range)
            domain_left_edge[i] = my_range[0]
            domain_right_edge[i] = my_range[1]
            setattr(self, range_name, self.parameters[range_name])

            bin_field = f"{ax}_field"
            if (
                isinstance(self.parameters[bin_field], str)
                and self.parameters[bin_field] == "None"
            ):
                self.parameters[bin_field] = None
            elif isinstance(self.parameters[bin_field], np.ndarray):
                self.parameters[bin_field] = tuple(
                    ["data", self.parameters[bin_field].astype(str)[1]]
                )
            setattr(self, bin_field, self.parameters[bin_field])
        self.domain_left_edge = domain_left_edge
        self.domain_right_edge = domain_right_edge

    def _setup_gas_alias(self):
        "Alias the grid type to gas with a field alias."
        for ftype, field in self.field_list:
            if ftype == "data":
                self.field_info.alias(("gas", field), (ftype, field))

    def create_field_info(self):
        super().create_field_info()
        if self.parameters["weight_field"] is not None:
            self.field_info.alias(
                self.parameters["weight_field"], (self.default_fluid_type, "weight")
            )

    def _set_derived_attrs(self):
        self.domain_center = 0.5 * (self.domain_right_edge + self.domain_left_edge)
        self.domain_width = self.domain_right_edge - self.domain_left_edge

    def print_key_parameters(self):
        if is_root():
            mylog.info("YTProfileDataset")
            for a in ["dimensionality", "profile_dimensions"] + [
                f"{ax}_{attr}"
                for ax in "xyz"[: self.dimensionality]
                for attr in ["field", "range", "log"]
            ]:
                v = getattr(self, a)
                mylog.info("Parameters: %-25s = %s", a, v)
        super().print_key_parameters()

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            if data_type == "yt_profile":
                return True
        return False


class YTClumpContainer(TreeContainer):
    def __init__(
        self, clump_id, global_id, parent_id, contour_key, contour_id, ds=None
    ):
        self.clump_id = clump_id
        self.global_id = global_id
        self.parent_id = parent_id
        self.contour_key = contour_key
        self.contour_id = contour_id
        self.parent = None
        self.ds = ds
        TreeContainer.__init__(self)

    def add_child(self, child):
        if self.children is None:
            self.children = []
        self.children.append(child)
        child.parent = self

    def __repr__(self):
        return "Clump[%d]" % self.clump_id

    def __getitem__(self, field):
        g = self.ds.data
        f = g._determine_fields(field)[0]
        if f[0] == "clump":
            return g[f][self.global_id]
        if self.contour_id == -1:
            return g[f]
        cfield = (f[0], f"contours_{self.contour_key.decode('utf-8')}")
        if f[0] == "grid":
            return g[f][g[cfield] == self.contour_id]
        return self.parent[f][g[cfield] == self.contour_id]


class YTClumpTreeDataset(YTNonspatialDataset):
    """Dataset for saved clump-finder data."""

    def __init__(self, filename, unit_system="cgs"):
        super().__init__(filename, unit_system=unit_system)
        self._load_tree()

    def _load_tree(self):
        my_tree = {}
        for i, clump_id in enumerate(self.data[("clump", "clump_id")]):
            my_tree[clump_id] = YTClumpContainer(
                clump_id,
                i,
                self.data["clump", "parent_id"][i],
                self.data["clump", "contour_key"][i],
                self.data["clump", "contour_id"][i],
                self,
            )
        for clump in my_tree.values():
            if clump.parent_id == -1:
                self.tree = clump
            else:
                parent = my_tree[clump.parent_id]
                parent.add_child(clump)

    _leaves = None

    @property
    def leaves(self):
        if self._leaves is None:
            self._leaves = []
            for clump in self.tree:
                if clump.children is None:
                    self._leaves.append(clump)
        return self._leaves
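    # Hedged sketch: reloading saved clump-finder output and walking the tree.
    # The file name is illustrative; tree iteration and the "leaves" property
    # come from this class and TreeContainer.
    #
    #   import yt
    #   clump_ds = yt.load("DD0046_clump_0.h5")
    #   for clump in clump_ds.tree:
    #       print(clump.clump_id)
    #   print(len(clump_ds.leaves))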

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        if not filename.endswith(".h5"):
            return False
        with h5py.File(filename, mode="r") as f:
            data_type = parse_h5_attr(f, "data_type")
            if data_type is None:
                return False
            if data_type == "yt_clump_tree":
                return True
        return False
