1import os
2import weakref
3
4import numpy as np
5
6from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
7from yt.data_objects.static_output import Dataset
8from yt.funcs import mylog, setdefaultattr
9from yt.geometry.grid_geometry_handler import GridIndex
10from yt.utilities.file_handler import HDF5FileHandler
11
12from .definitions import geometry_parameters
13from .fields import GAMERFieldInfo
14
15
class GAMERGrid(AMRGridPatch):
    """One yt grid of a GAMER dataset.

    The parent/children links are filled in later by the owning index
    (``GAMERHierarchy._populate_grid_objects``).
    """

    _id_offset = 0

    def __init__(self, id, index, level):
        super().__init__(id, filename=index.index_filename, index=index)
        # the index assigns a single parent grid later, so start from None
        # rather than an empty list (do NOT initialize Parent as [])
        self.Parent = None
        self.Children = []
        self.Level = level

    def __repr__(self):
        fields = (self.id, self.ActiveDimensions)
        return "GAMERGrid_%09i (dimension = %s)" % fields
27
28
class GAMERHierarchy(GridIndex):
    """Grid index for a GAMER HDF5 snapshot.

    Every yt grid wraps one GAMER patch group: ``pgroup = refine_by**3``
    consecutive patches in the ``Tree/*`` datasets of the file. Each grid
    therefore has ``PatchSize * refine_by`` cells per dimension.
    """

    grid = GAMERGrid
    _preload_implemented = True  # since gamer defines "_read_chunk_data" in io.py

    def __init__(self, ds, dataset_type="gamer"):
        self.dataset_type = dataset_type
        self.dataset = weakref.proxy(ds)
        self.index_filename = self.dataset.parameter_filename
        self.directory = os.path.dirname(self.index_filename)
        self._handle = ds._handle
        self._group_grid = ds._group_grid
        self._group_particle = ds._group_particle
        self.float_type = "float64"  # fixed even when FLOAT8 is off
        self._particle_handle = ds._particle_handle
        self.refine_by = ds.refine_by
        self.pgroup = self.refine_by**3  # number of patches in a patch group
        GridIndex.__init__(self, ds, dataset_type)

    def _detect_output_fields(self):
        """Build ``field_list`` from the grid and (optional) particle groups."""
        # grid fields
        self.field_list = [("gamer", v) for v in self._group_grid.keys()]

        # particle fields
        if self._group_particle is not None:
            self.field_list += [("io", v) for v in self._group_particle.keys()]

    def _count_grids(self):
        """Set ``num_grids`` to the number of patch groups over all levels."""
        self.num_grids = self.dataset.parameters["NPatch"].sum() // self.pgroup

    def _parse_index(self):
        """Fill the per-grid index arrays and create the grid objects.

        Sets the levels and left/right edges (code units, shifted by the
        domain left edge), ``grids``, ``max_level``, ``grid_particle_count``
        and ``_particle_indices``.
        """
        parameters = self.dataset.parameters
        gid0 = 0
        # one corner per patch group: keep the first patch of each group
        grid_corner = self._handle["Tree/Corner"][()][:: self.pgroup]
        convert2physical = self._handle["Tree/Corner"].attrs["Cvt2Phy"]

        self.grid_dimensions[:] = parameters["PatchSize"] * self.refine_by

        for lv in range(0, parameters["NLevel"]):
            num_grids_level = parameters["NPatch"][lv] // self.pgroup
            if num_grids_level == 0:
                break

            patch_scale = (
                parameters["PatchSize"] * parameters["CellScale"][lv] * self.refine_by
            )

            # set the level and edge of each grid
            # (left/right_edge are YT arrays in code units)
            self.grid_levels.flat[gid0 : gid0 + num_grids_level] = lv
            self.grid_left_edge[gid0 : gid0 + num_grids_level] = (
                grid_corner[gid0 : gid0 + num_grids_level] * convert2physical
            )
            self.grid_right_edge[gid0 : gid0 + num_grids_level] = (
                grid_corner[gid0 : gid0 + num_grids_level] + patch_scale
            ) * convert2physical

            gid0 += num_grids_level
        self.grid_left_edge += self.dataset.domain_left_edge
        self.grid_right_edge += self.dataset.domain_left_edge

        # allocate all grid objects
        self.grids = np.empty(self.num_grids, dtype="object")
        for i in range(self.num_grids):
            self.grids[i] = self.grid(i, self, self.grid_levels.flat[i])

        # maximum level with patches (which can be lower than MAX_LEVEL)
        self.max_level = self.grid_levels.max()

        # number of particles in each grid (sum over the patches of each group)
        try:
            self.grid_particle_count[:] = np.sum(
                self._handle["Tree/NPar"][()].reshape(-1, self.pgroup), axis=1
            )[:, None]
        except KeyError:
            self.grid_particle_count[:] = 0.0

        # calculate the starting particle indices for each grid (starting from 0)
        # --> note that the last element must store the total number of particles
        #    (see _read_particle_coords and _read_particle_fields in io.py)
        # --> use ravel() rather than squeeze(): with a single grid, squeeze()
        #     would yield a 0-d array, which np.add.accumulate rejects
        self._particle_indices = np.zeros(self.num_grids + 1, dtype="int64")
        np.add.accumulate(
            self.grid_particle_count.ravel(), out=self._particle_indices[1:]
        )

    def _populate_grid_objects(self):
        """Link grids to their children/parent and finish per-grid setup.

        Children come from ``Tree/Son``: patch ids are divided by ``pgroup``
        to get grid ids, and negative entries are skipped. Each grid then
        gets ``_prepare_grid()`` and ``_setup_dx()``. In debug mode the
        resulting links are validated.
        """
        son_list = self._handle["Tree/Son"][()]

        for gid in range(self.num_grids):
            grid = self.grids[gid]
            son_gid0 = (
                son_list[gid * self.pgroup : (gid + 1) * self.pgroup] // self.pgroup
            )

            # set up the parent-children relationship
            grid.Children = [self.grids[t] for t in son_gid0[son_gid0 >= 0]]

            for son_grid in grid.Children:
                son_grid.Parent = grid

            # set up other grid attributes
            grid._prepare_grid()
            grid._setup_dx()

        # validate the parent-children relationship in the debug mode
        if self.dataset._debug:
            self._validate_parent_children_relationship()

    # for _debug mode only
    def _validate_parent_children_relationship(self):
        """Cross-check the parent/children links of every grid.

        The checks are:

        * parent and children point back to the grid;
        * every level > 0 grid has a parent that matches ``Tree/Father``;
        * every child lies within its parent's left/right edges.

        Raises
        ------
        AssertionError
            On an inconsistent link. These are ``assert`` statements and are
            skipped under ``python -O``.
        ValueError
            If a child extends beyond its parent's edges.
        """
        mylog.info("Validating the parent-children relationship ...")

        father_list = self._handle["Tree/Father"][()]

        for grid in self.grids:
            # parent->children == itself
            # (guard the message: Parent.Children may be empty exactly when this
            #  check fails, and indexing it would hide the real error)
            if grid.Parent is not None:
                assert (
                    grid in grid.Parent.Children
                ), "Grid %d, Parent %d, Parent->Children[0] %d" % (
                    grid.id,
                    grid.Parent.id,
                    grid.Parent.Children[0].id if grid.Parent.Children else -999,
                )

            # children->parent == itself
            for c in grid.Children:
                assert c.Parent is grid, "Grid %d, Children %d, Children->Parent %d" % (
                    grid.id,
                    c.id,
                    c.Parent.id if c.Parent is not None else -999,
                )

            # all refinement grids should have parent
            if grid.Level > 0:
                assert (
                    grid.Parent is not None and grid.Parent.id >= 0
                ), "Grid %d, Level %d, Parent %d" % (
                    grid.id,
                    grid.Level,
                    grid.Parent.id if grid.Parent is not None else -999,
                )

            # parent index is consistent with the loaded dataset
            if grid.Level > 0:
                father_gid = father_list[grid.id * self.pgroup] // self.pgroup
                assert (
                    father_gid == grid.Parent.id
                ), "Grid %d, Level %d, Parent_Found %d, Parent_Expect %d" % (
                    grid.id,
                    grid.Level,
                    grid.Parent.id,
                    father_gid,
                )

            # edges between children and parent
            for c in grid.Children:
                for d in range(0, 3):
                    msgL = (
                        "Grid %d, Child %d, Grid->EdgeL %14.7e, Children->EdgeL %14.7e"
                        % (grid.id, c.id, grid.LeftEdge[d], c.LeftEdge[d])
                    )
                    msgR = (
                        "Grid %d, Child %d, Grid->EdgeR %14.7e, Children->EdgeR %14.7e"
                        % (grid.id, c.id, grid.RightEdge[d], c.RightEdge[d])
                    )
                    if not grid.LeftEdge[d] <= c.LeftEdge[d]:
                        raise ValueError(msgL)

                    if not grid.RightEdge[d] >= c.RightEdge[d]:
                        raise ValueError(msgR)

        mylog.info("Check passed")
203
204
class GAMERDataset(Dataset):
    """yt ``Dataset`` for GAMER HDF5 snapshots.

    Grid data are read from ``GridData`` (new format) or ``Data`` (old
    format). Particle data are read from ``Particle`` when present.
    """

    _index_class = GAMERHierarchy
    _field_info_class = GAMERFieldInfo
    _handle = None
    _group_grid = None
    _group_particle = None
    _debug = False  # debug mode for the GAMER frontend

    def __init__(
        self,
        filename,
        dataset_type="gamer",
        storage_filename=None,
        particle_filename=None,
        units_override=None,
        unit_system="cgs",
        default_species_fields=None,
    ):
        # an already-initialized instance keeps its open handles untouched
        if self._handle is not None:
            return

        self.fluid_types += ("gamer",)
        self._handle = HDF5FileHandler(filename)
        self.particle_filename = particle_filename

        # to catch both the new and old data formats for the grid data
        try:
            self._group_grid = self._handle["GridData"]
        except KeyError:
            self._group_grid = self._handle["Data"]

        if "Particle" in self._handle:
            self._group_particle = self._handle["Particle"]

        if self.particle_filename is None:
            self._particle_handle = self._handle
        else:
            self._particle_handle = HDF5FileHandler(self.particle_filename)

        # currently GAMER only supports refinement by a factor of 2
        self.refine_by = 2

        Dataset.__init__(
            self,
            filename,
            dataset_type,
            units_override=units_override,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )
        self.storage_filename = storage_filename

    def _set_code_unit_attributes(self):
        """Set code units from the ``Unit_*`` parameters, or fall back to CGS.

        If ``Opt__Unit`` is off, every unit defaults to 1 in CGS. The one
        exception is the magnetic unit, which defaults to sqrt(4*pi) gauss.
        Warnings are logged unless ``units_override`` was supplied.
        """
        if self.parameters["Opt__Unit"]:
            # GAMER units are always in CGS
            setdefaultattr(
                self, "length_unit", self.quan(self.parameters["Unit_L"], "cm")
            )
            setdefaultattr(self, "mass_unit", self.quan(self.parameters["Unit_M"], "g"))
            setdefaultattr(self, "time_unit", self.quan(self.parameters["Unit_T"], "s"))

            if self.mhd:
                setdefaultattr(
                    self, "magnetic_unit", self.quan(self.parameters["Unit_B"], "gauss")
                )

        else:
            if len(self.units_override) == 0:
                mylog.warning(
                    "Cannot determine code units ==> "
                    "Use units_override to specify the units"
                )

            for unit, value, cgs in [
                ("length", 1.0, "cm"),
                ("time", 1.0, "s"),
                ("mass", 1.0, "g"),
                ("magnetic", np.sqrt(4.0 * np.pi), "gauss"),
            ]:
                setdefaultattr(self, f"{unit}_unit", self.quan(value, cgs))

                if len(self.units_override) == 0:
                    mylog.warning("Assuming %8s unit = %f %s", unit, value, cgs)

    def _parse_parameter_file(self):
        """Load ``/Info`` records into ``self.parameters`` and derive attributes.

        The derived attributes are time, domain, periodicity, cosmology,
        model aliases and geometry.
        """
        # code-specific parameters
        for t in self._handle["Info"]:
            info_category = self._handle["Info"][t]
            for v in info_category.dtype.names:
                self.parameters[v] = info_category[v]

        # shortcut for self.parameters
        parameters = self.parameters

        # reset 'Model' to be more readable
        # (no longer regard MHD as a separate model)
        if parameters["Model"] == 1:
            parameters["Model"] = "Hydro"
        elif parameters["Model"] == 3:
            parameters["Model"] = "ELBDM"
        else:
            parameters["Model"] = "Unknown"

        # simulation time and domain
        self.current_time = parameters["Time"][0]
        self.dimensionality = 3  # always 3D
        self.domain_left_edge = parameters.get(
            "BoxEdgeL", np.array([0.0, 0.0, 0.0])
        ).astype("f8")
        self.domain_right_edge = parameters.get(
            "BoxEdgeR", parameters["BoxSize"]
        ).astype("f8")
        self.domain_dimensions = parameters["NX0"].astype("int64")

        # periodicity: the code for a periodic boundary changed at format 2106;
        # only the lower-face BC of each axis (entries 0, 2, 4) is inspected
        if parameters["FormatVersion"] >= 2106:
            periodic_bc = 1
        else:
            periodic_bc = 0
        self._periodicity = (
            bool(parameters["Opt__BC_Flu"][0] == periodic_bc),
            bool(parameters["Opt__BC_Flu"][2] == periodic_bc),
            bool(parameters["Opt__BC_Flu"][4] == periodic_bc),
        )

        # cosmological parameters
        if parameters["Comoving"]:
            self.cosmological_simulation = 1
            # in comoving runs "Time" is used as the scale factor here
            self.current_redshift = 1.0 / self.current_time - 1.0
            self.omega_matter = parameters["OmegaM0"]
            self.omega_lambda = 1.0 - self.omega_matter
            # default to 0.7 for old data format
            self.hubble_constant = parameters.get("Hubble0", 0.7)
        else:
            self.cosmological_simulation = 0
            self.current_redshift = 0.0
            self.omega_matter = 0.0
            self.omega_lambda = 0.0
            self.hubble_constant = 0.0

        # make aliases to some frequently used variables
        if parameters["Model"] == "Hydro":
            self.gamma = parameters["Gamma"]
            self.eos = parameters.get("EoS", 1)  # Assume gamma-law by default
            # default to 0.6 for old data format
            self.mu = parameters.get(
                "MolecularWeight", 0.6
            )  # Assume ionized primordial by default
            self.mhd = parameters.get("Magnetohydrodynamics", 0)
            self.srhd = parameters.get("SRHydrodynamics", 0)
        else:
            self.mhd = 0
            self.srhd = 0

        # old data format (version < 2210) did not contain any information of code units
        self.parameters.setdefault("Opt__Unit", 0)

        self.geometry = geometry_parameters[parameters.get("Coordinate", 1)]

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        """Return True if ``filename`` is an HDF5 file containing ``/Info/KeyInfo``.

        Any failure while probing (missing file, non-HDF5 content, ...) is
        treated as "not a GAMER dataset" and yields False. The probe handle
        is always closed instead of being left for garbage collection.
        """
        f = None
        try:
            # define a unique way to identify GAMER datasets
            f = HDF5FileHandler(filename)
            return "Info" in f["/"].keys() and "KeyInfo" in f["/Info"].keys()
        except Exception:
            return False
        finally:
            if f is not None:
                f.close()
376