import os
import weakref

import numpy as np

from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.static_output import Dataset
from yt.funcs import just_one, setdefaultattr
from yt.geometry.grid_geometry_handler import GridIndex
from yt.units.dimensions import dimensionless as sympy_one
from yt.units.unit_object import Unit
from yt.units.unit_systems import unit_system_registry
from yt.utilities.exceptions import YTGDFUnknownGeometry
from yt.utilities.lib.misc_utilities import get_box_grids_level
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py

from .fields import GDFFieldInfo

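# Map the integer ``geometry`` attribute stored in a GDF file onto yt's named
# coordinate systems (used in ``GDFDataset._parse_parameter_file``).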
GEOMETRY_TRANS = {
    0: "cartesian",
    1: "polar",
    2: "cylindrical",
    3: "spherical",
}


class GDFGrid(AMRGridPatch):
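    """A single rectangular patch of GDF data at a fixed refinement level."""
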
    _id_offset = 0

    def __init__(self, id, index, level, start, dimensions):
        AMRGridPatch.__init__(self, id, filename=index.index_filename, index=index)
        self.Parent = []
        self.Children = []
        self.Level = level
        self.start_index = start.copy()
        self.stop_index = self.start_index + dimensions
        self.ActiveDimensions = dimensions.copy()

    def _setup_dx(self):
        # Compute this grid's cell widths.  We do not assume dx == dy == dz
        # here, even though other parts of the code may.
        id = self.id - self._id_offset
        if len(self.Parent) > 0:
            self.dds = self.Parent[0].dds / self.ds.refine_by
        else:
            LE, RE = self.index.grid_left_edge[id, :], self.index.grid_right_edge[id, :]
            self.dds = np.array((RE - LE) / self.ActiveDimensions)
        if self.ds.data_software != "piernik":
            # For data not written by piernik, unused dimensions get a unit
            # cell width.
            if self.ds.dimensionality < 2:
                self.dds[1] = 1.0
            if self.ds.dimensionality < 3:
                self.dds[2] = 1.0
        self.field_data["dx"], self.field_data["dy"], self.field_data["dz"] = self.dds
        self.dds = self.ds.arr(self.dds, "code_length")


class GDFHierarchy(GridIndex):
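    """Grid index for GDF data; grid metadata is read from the HDF5 file."""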

    grid = GDFGrid

    def __init__(self, ds, dataset_type="grid_data_format"):
        self.dataset = weakref.proxy(ds)
        self.index_filename = self.dataset.parameter_filename
        h5f = h5py.File(self.index_filename, mode="r")
        self.dataset_type = dataset_type
        GridIndex.__init__(self, ds, dataset_type)
        self.directory = os.path.dirname(self.index_filename)
        h5f.close()

    def _detect_output_fields(self):
        h5f = h5py.File(self.index_filename, mode="r")
        self.field_list = [("gdf", str(f)) for f in h5f["field_types"].keys()]
        h5f.close()

    def _count_grids(self):
        h5f = h5py.File(self.index_filename, mode="r")
        self.num_grids = h5f["/grid_parent_id"].shape[0]
        h5f.close()

    def _parse_index(self):
        h5f = h5py.File(self.index_filename, mode="r")
        dxs = []
        self.grids = np.empty(self.num_grids, dtype="object")
        levels = (h5f["grid_level"][:]).copy()
        glis = (h5f["grid_left_index"][:]).copy()
        gdims = (h5f["grid_dimensions"][:]).copy()
        # Axes along which the domain (and every grid) is a single cell wide
        # are inactive and keep the coarse cell width at every level.
        active_dims = ~(
            (np.max(gdims, axis=0) == 1) & (self.dataset.domain_dimensions == 1)
        )

        for i in range(levels.shape[0]):
            self.grids[i] = self.grid(i, self, levels[i], glis[i], gdims[i])
            self.grids[i]._level_id = levels[i]

            # Cell width for this grid: the coarse cell width divided by
            # refine_by**level along each active axis.
            dx = (
                self.dataset.domain_right_edge - self.dataset.domain_left_edge
            ) / self.dataset.domain_dimensions
            dx[active_dims] /= self.dataset.refine_by ** levels[i]
            dxs.append(dx.in_units("code_length"))
        dx = self.dataset.arr(dxs, units="code_length")
        self.grid_left_edge = self.dataset.domain_left_edge + dx * glis
        self.grid_dimensions = gdims.astype("int32")
        self.grid_right_edge = self.grid_left_edge + dx * self.grid_dimensions
        self.grid_particle_count = h5f["grid_particle_count"][:]
        del levels, glis, gdims
        h5f.close()

    def _populate_grid_objects(self):
        mask = np.empty(self.grids.size, dtype="int32")
        for g in self.grids:
            g._prepare_grid()
            g._setup_dx()

        for gi, g in enumerate(self.grids):
            # Link each grid to its children (and, symmetrically, parents),
            # then use get_box_grids_level to find same-level grids that
            # overlap this one; those later in the list become
            # OverlappingSiblings.
            g.Children = self._get_grid_children(g)
            for g1 in g.Children:
                g1.Parent.append(g)
            get_box_grids_level(
                self.grid_left_edge[gi, :],
                self.grid_right_edge[gi, :],
                self.grid_levels[gi],
                self.grid_left_edge,
                self.grid_right_edge,
                self.grid_levels,
                mask,
            )
            m = mask.astype("bool")
            m[gi] = False
            siblings = self.grids[gi:][m[gi:]]
            if len(siblings) > 0:
                g.OverlappingSiblings = siblings.tolist()
        self.max_level = self.grid_levels.max()

    def _get_box_grids(self, left_edge, right_edge):
        """
        Return all grids that overlap the region between left_edge and
        right_edge, along with their indices.
        """
        eps = np.finfo(np.float64).eps
        grid_i = np.where(
            np.all((self.grid_right_edge - left_edge) > eps, axis=1)
            & np.all((right_edge - self.grid_left_edge) > eps, axis=1)
        )

        return self.grids[grid_i], grid_i

    def _get_grid_children(self, grid):
        mask = np.zeros(self.num_grids, dtype="bool")
        grids, grid_ind = self._get_box_grids(grid.LeftEdge, grid.RightEdge)
        mask[grid_ind] = True
        return [g for g in self.grids[mask] if g.Level == grid.Level + 1]


class GDFDataset(Dataset):
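    """Dataset class for data written in the Grid Data Format (GDF).

    A minimal usage sketch (the filename here is hypothetical); ``yt.load``
    dispatches to this class via ``_is_valid``::

        import yt

        ds = yt.load("my_gdf_output.h5")
    """
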
    _index_class = GDFHierarchy
    _field_info_class = GDFFieldInfo

    def __init__(
        self,
        filename,
        dataset_type="grid_data_format",
        storage_filename=None,
        geometry=None,
        units_override=None,
        unit_system="cgs",
        default_species_fields=None,
    ):
        self.geometry = geometry
        self.fluid_types += ("gdf",)
        Dataset.__init__(
            self,
            filename,
            dataset_type,
            units_override=units_override,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )
        self.storage_filename = storage_filename
        self.filename = filename

    def _set_code_unit_attributes(self):
        """
        Generate the conversions to various physical units based on the
        parameter file.
        """

        # This should be improved.
        h5f = h5py.File(self.parameter_filename, mode="r")
        for field_name in h5f["/field_types"]:
            current_field = h5f[f"/field_types/{field_name}"]
            if "field_to_cgs" in current_field.attrs:
                field_conv = current_field.attrs["field_to_cgs"]
                self.field_units[field_name] = just_one(field_conv)
            elif "field_units" in current_field.attrs:
                # The units may be stored as bytes, str, or an array of either.
                field_units = just_one(current_field.attrs["field_units"])
                if isinstance(field_units, bytes):
                    field_units = field_units.decode("utf8")
                self.field_units[field_name] = field_units
            else:
                self.field_units[field_name] = ""

        if "dataset_units" in h5f:
            for unit_name in h5f["/dataset_units"]:
                current_unit = h5f[f"/dataset_units/{unit_name}"]
                value = current_unit[()]
                unit = current_unit.attrs["unit"]
                # We need to convert this to a Unit object and check its
                # dimensions, because the unit can be something like
                # 'dimensionless/dimensionless**3', so naive string
                # comparisons are insufficient.
                unit = Unit(unit, registry=self.unit_registry)
                if unit_name.endswith("_unit") and unit.dimensions is sympy_one:
                    # Catch code units that are dimensionless and assign them
                    # CGS units instead.  setdefaultattr will skip code units
                    # that have already been set via units_override.
                    un = unit_name[:-5]
                    un = un.replace("magnetic", "magnetic_field_cgs", 1)
                    unit = unit_system_registry["cgs"][un]
                setdefaultattr(self, unit_name, self.quan(value, unit))
                if unit_name in h5f["/field_types"]:
                    if unit_name in self.field_units:
                        mylog.warning(
                            "'field_units' was overridden by 'dataset_units/%s'",
                            unit_name,
                        )
                    self.field_units[unit_name] = str(unit)
        else:
            setdefaultattr(self, "length_unit", self.quan(1.0, "cm"))
            setdefaultattr(self, "mass_unit", self.quan(1.0, "g"))
            setdefaultattr(self, "time_unit", self.quan(1.0, "s"))

        h5f.close()

    def _parse_parameter_file(self):
        self._handle = h5py.File(self.parameter_filename, mode="r")
        if "data_software" in self._handle["gridded_data_format"].attrs:
            self.data_software = self._handle["gridded_data_format"].attrs[
                "data_software"
            ]
        else:
            self.data_software = "unknown"
        sp = self._handle["/simulation_parameters"].attrs
        if self.geometry is None:
            geometry = just_one(sp.get("geometry", 0))
            try:
                self.geometry = GEOMETRY_TRANS[geometry]
            except KeyError as e:
                raise YTGDFUnknownGeometry(geometry) from e
        self.parameters.update(sp)
        self.domain_left_edge = sp["domain_left_edge"][:]
        self.domain_right_edge = sp["domain_right_edge"][:]
        self.domain_dimensions = sp["domain_dimensions"][:]
        refine_by = sp["refine_by"]
        if refine_by is None:
            refine_by = 2
        self.refine_by = refine_by
        self.dimensionality = sp["dimensionality"]
        self.current_time = sp["current_time"]
        self.unique_identifier = sp["unique_identifier"]
        self.cosmological_simulation = sp["cosmological_simulation"]
        if sp["num_ghost_zones"] != 0:
            raise RuntimeError("GDF datasets with ghost zones are not supported.")
        self.num_ghost_zones = sp["num_ghost_zones"]
        self.field_ordering = sp["field_ordering"]
        self.boundary_conditions = sp["boundary_conditions"][:]
        # A boundary flag of 0 marks a periodic axis; only the lower face of
        # each axis (every other entry) is inspected here.
        self._periodicity = tuple(bnd == 0 for bnd in self.boundary_conditions[::2])
        if self.cosmological_simulation:
            self.current_redshift = sp["current_redshift"]
            self.omega_lambda = sp["omega_lambda"]
            self.omega_matter = sp["omega_matter"]
            self.hubble_constant = sp["hubble_constant"]
        else:
            self.current_redshift = 0.0
            self.omega_lambda = 0.0
            self.omega_matter = 0.0
            self.hubble_constant = 0.0
            self.cosmological_simulation = 0
        self.parameters["Time"] = 1.0  # Hardcode the time conversion for now.
        # Hardcode this until field staggering is supported.
        self.parameters["HydroMethod"] = 0
        self._handle.close()
        del self._handle

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        try:
            with h5py.File(filename, mode="r") as fileh:
                return "gridded_data_format" in fileh
        except Exception:
            return False

    def __str__(self):
        return self.basename.rsplit(".", 1)[0]