import os

import numpy as np

from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.static_output import Dataset
from yt.fields.field_info_container import NullFunc
from yt.frontends.enzo.misc import cosmology_get_units
from yt.frontends.enzo_e.fields import EnzoEFieldInfo
from yt.frontends.enzo_e.misc import (
    get_block_info,
    get_child_index,
    get_root_block_id,
    get_root_blocks,
    is_parent,
    nested_dict_get,
)
from yt.funcs import get_pbar, setdefaultattr
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.cosmology import Cosmology
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py, _libconf as libconf


class EnzoEGrid(AMRGridPatch):
    """
    Class representing a single Enzo-E grid instance.
    """

    _id_offset = 0
    _refine_by = 2

    def __init__(self, id, index, block_name, filename=None):
        """
        Returns an instance of EnzoEGrid with *id*, associated with
        *filename* and *index*.
        """
        # All of the field parameters will be passed to us as needed.
        AMRGridPatch.__init__(self, id, filename=filename, index=index)
        self.block_name = block_name
        self._children_ids = None
        self._parent_id = -1
        self.Level = -1

    def __repr__(self):
        return "EnzoEGrid_%04d" % self.id

    def _prepare_grid(self):
        """Copies all the appropriate attributes from the index."""
        h = self.index  # cache it
        my_ind = self.id - self._id_offset
        self.ActiveDimensions = h.grid_dimensions[my_ind]
        self.LeftEdge = h.grid_left_edge[my_ind]
        self.RightEdge = h.grid_right_edge[my_ind]

    def get_parent_id(self, desc_block_name):
        if self.block_name == desc_block_name:
            raise RuntimeError("Child and parent are the same!")
        dim = self.ds.dimensionality
        d_block = desc_block_name[1:].replace(":", "")
        parent = self

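        # Starting from this (root) block, descend toward the descendant
        # block one generation at a time until the block directly above it
        # in the hierarchy is reached; that block is the parent.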
        while True:
            a_block = parent.block_name[1:].replace(":", "")
            gengap = (len(d_block) - len(a_block)) / dim
            if gengap <= 1:
                return parent.id
            cid = get_child_index(a_block, d_block)
            parent = self.index.grids[parent._children_ids[cid]]

    def add_child(self, child):
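        # Lazily allocate one slot per possible child (refine_by**dim),
        # with -1 marking positions that do not yet have a child.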
        if self._children_ids is None:
            self._children_ids = -1 * np.ones(
                self._refine_by ** self.ds.dimensionality, dtype=np.int64
            )

        a_block = self.block_name[1:].replace(":", "")
        d_block = child.block_name[1:].replace(":", "")
        cid = get_child_index(a_block, d_block)
        self._children_ids[cid] = child.id

    _particle_count = None

    @property
    def particle_count(self):
        if self._particle_count is None:
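            # Count particles of each type by reading the size of one
            # sample field per particle type from this block's group in
            # the HDF5 data file.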
            with h5py.File(self.filename, mode="r") as f:
                fnstr = "{}/{}".format(
                    self.block_name,
                    self.ds.index.io._sep.join(["particle", "%s", "%s"]),
                )
                self._particle_count = {
                    ptype: f.get(fnstr % (ptype, pfield)).size
                    for ptype, pfield in self.ds.index.io.sample_pfields.items()
                }
        return self._particle_count

    _total_particles = None

    @property
    def total_particles(self):
        if self._total_particles is None:
            self._total_particles = sum(self.particle_count.values())
        return self._total_particles

    @property
    def Parent(self):
        if self._parent_id == -1:
            return None
        return self.index.grids[self._parent_id]

    @property
    def Children(self):
        if self._children_ids is None:
            return []
        return [self.index.grids[cid] for cid in self._children_ids]


class EnzoEHierarchy(GridIndex):

    _strip_path = False
    grid = EnzoEGrid
    _preload_implemented = True

    def __init__(self, ds, dataset_type):

        self.dataset_type = dataset_type
        self.directory = os.path.dirname(ds.parameter_filename)
        self.index_filename = ds.parameter_filename
        if os.path.getsize(self.index_filename) == 0:
            raise OSError(-1, "File empty", self.index_filename)

        GridIndex.__init__(self, ds, dataset_type)
        self.dataset.dataset_type = self.dataset_type

    def _count_grids(self):
        fblock_size = 32768
        f = open(self.ds.parameter_filename)
        f.seek(0, 2)
        file_size = f.tell()
        nblocks = np.ceil(float(file_size) / fblock_size).astype(np.int64)
        f.seek(0)
        offset = f.tell()
        ngrids = 0
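        # Each line of the block_list file names one block, so the number
        # of grids is the number of newlines. Read the file in fixed-size
        # chunks to avoid holding it all in memory.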
        for _ in range(nblocks):
            my_block = min(fblock_size, file_size - offset)
            buff = f.read(my_block)
            ngrids += buff.count("\n")
            offset += my_block
        f.close()
        self.num_grids = ngrids
        self.dataset_type = "enzo_e"

    def _parse_index(self):
        self.grids = np.empty(self.num_grids, dtype="object")

        c = 1
        pbar = get_pbar("Parsing Hierarchy", self.num_grids)
        f = open(self.ds.parameter_filename)
        fblock_size = 32768
        f.seek(0, 2)
        file_size = f.tell()
        nblocks = np.ceil(float(file_size) / fblock_size).astype(np.int64)
        f.seek(0)
        offset = f.tell()
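        # The file is again read in fixed-size chunks; lstr carries any
        # partial line left at the end of one chunk into the next.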
        lstr = ""
        # place child blocks after the root blocks
        rbdim = self.ds.root_block_dimensions
        nroot_blocks = rbdim.prod()
        child_id = nroot_blocks

        last_pid = None
        for _ib in range(nblocks):
            fblock = min(fblock_size, file_size - offset)
            buff = lstr + f.read(fblock)
            bnl = 0
            for _inl in range(buff.count("\n")):
                nnl = buff.find("\n", bnl)
                line = buff[bnl:nnl]
                block_name, block_file = line.split()

                # Handling of the B, B_, and B__ blocks is consistent with
                # other unrefined blocks
                level, left, right = get_block_info(block_name)
                rbindex = get_root_block_id(block_name)
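                # Flatten the 3D root-block index into a linear id; root
                # blocks occupy ids [0, nroot_blocks) and child blocks are
                # numbered after them.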
                rbid = (
                    rbindex[0] * rbdim[1:].prod()
                    + rbindex[1] * rbdim[2:].prod()
                    + rbindex[2]
                )

                # There are also blocks at lower level than the
                # real root blocks. These can be ignored.
                if level == 0:
                    check_root = get_root_blocks(block_name).prod()
                    if check_root < nroot_blocks:
                        level = -1

                if level == -1:
                    grid_id = child_id
                    parent_id = -1
                    child_id += 1
                elif level == 0:
                    grid_id = rbid
                    parent_id = -1
                else:
                    grid_id = child_id
                    # Try the last parent_id first
                    if last_pid is not None and is_parent(
                        self.grids[last_pid].block_name, block_name
                    ):
                        parent_id = last_pid
                    else:
                        parent_id = self.grids[rbid].get_parent_id(block_name)
                    last_pid = parent_id
                    child_id += 1

                my_grid = self.grid(
                    grid_id,
                    self,
                    block_name,
                    filename=os.path.join(self.directory, block_file),
                )
                my_grid.Level = level
                my_grid._parent_id = parent_id

                self.grids[grid_id] = my_grid
                self.grid_levels[grid_id] = level
                self.grid_left_edge[grid_id] = left
                self.grid_right_edge[grid_id] = right
                self.grid_dimensions[grid_id] = self.ds.active_grid_dimensions

                if level > 0:
                    self.grids[parent_id].add_child(my_grid)

                bnl = nnl + 1
                pbar.update(c)
                c += 1
            lstr = buff[bnl:]
            offset += fblock

        f.close()
        pbar.finish()

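        # get_block_info returns edges in the unit cube; rescale them to
        # the physical domain.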
        slope = self.ds.domain_width / self.ds.arr(np.ones(3), "code_length")
        self.grid_left_edge = self.grid_left_edge * slope + self.ds.domain_left_edge
        self.grid_right_edge = self.grid_right_edge * slope + self.ds.domain_left_edge

    def _populate_grid_objects(self):
        for g in self.grids:
            g._prepare_grid()
            g._setup_dx()
        self.max_level = self.grid_levels.max()

    def _setup_derived_fields(self):
        super()._setup_derived_fields()
        for fname, field in self.ds.field_info.items():
            if not field.particle_type:
                continue
            if isinstance(fname, tuple):
                continue
            if field._function is NullFunc:
                continue

    def _get_particle_type_counts(self):
        return {
            ptype: sum(g.particle_count[ptype] for g in self.grids)
            for ptype in self.ds.particle_types_raw
        }

    def _detect_output_fields(self):
        self.field_list = []
        # Do this only on the root processor to save disk work.
        if self.comm.rank in (0, None):
            # Just check the first grid.
            grid = self.grids[0]
            field_list, ptypes = self.io._read_field_names(grid)
            mylog.debug("Grid %s has: %s", grid.id, field_list)
        else:
            field_list = None
            ptypes = None
        self.field_list = list(self.comm.mpi_bcast(field_list))
        self.dataset.particle_types = list(self.comm.mpi_bcast(ptypes))
        self.dataset.particle_types_raw = self.dataset.particle_types[:]


class EnzoEDataset(Dataset):
    """
    Enzo-E-specific output, set at a fixed time.
    """

    refine_by = 2
    _index_class = EnzoEHierarchy
    _field_info_class = EnzoEFieldInfo
    _suffix = ".block_list"
    particle_types = None
    particle_types_raw = None

    def __init__(
        self,
        filename,
        dataset_type=None,
        file_style=None,
        parameter_override=None,
        conversion_override=None,
        storage_filename=None,
        units_override=None,
        unit_system="cgs",
        default_species_fields=None,
    ):
        """
        This is a stripped-down class that simply reads and parses
        *filename* without looking at the index.  *dataset_type* gets passed
        to the index to pre-determine the style of data-output.  However,
        it is not strictly necessary.  Optionally you may specify a
        *parameter_override* dictionary that will override anything in the
        parameter file and a *conversion_override* dictionary that consists
        of {fieldname : conversion_to_cgs} that will override the #DataCGS.
        """
        self.fluid_types += ("enzoe",)
        if parameter_override is None:
            parameter_override = {}
        self._parameter_override = parameter_override
        if conversion_override is None:
            conversion_override = {}
        self._conversion_override = conversion_override
        self.storage_filename = storage_filename
        Dataset.__init__(
            self,
            filename,
            dataset_type,
            file_style=file_style,
            units_override=units_override,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )

    def _parse_parameter_file(self):
        """
        Parses the parameter file and establishes the various
        dictionaries.
        """

        f = open(self.parameter_filename)
        # get dimension from first block name
        b0, fn0 = f.readline().strip().split()
        level0, left0, right0 = get_block_info(b0, min_dim=0)
        root_blocks = get_root_blocks(b0)
        f.close()
        self.dimensionality = left0.size
        self._periodicity = tuple(np.ones(self.dimensionality, dtype=bool))

        lcfn = self.parameter_filename[: -len(self._suffix)] + ".libconfig"
        if os.path.exists(lcfn):
            with open(lcfn) as lf:
                self.parameters = libconf.load(lf)
            cosmo = nested_dict_get(self.parameters, ("Physics", "cosmology"))
            if cosmo is not None:
                self.cosmological_simulation = 1
                co_pars = [
                    "hubble_constant_now",
                    "omega_matter_now",
                    "omega_lambda_now",
                    "comoving_box_size",
                    "initial_redshift",
                ]
                co_dict = {
                    attr: nested_dict_get(
                        self.parameters, ("Physics", "cosmology", attr)
                    )
                    for attr in co_pars
                }
                for attr in ["hubble_constant", "omega_matter", "omega_lambda"]:
                    setattr(self, attr, co_dict[f"{attr}_now"])

                # Current redshift is not stored, so it's not possible
                # to set all cosmological units yet.
                # Get the time units and use that to figure out redshift.
                k = cosmology_get_units(
                    self.hubble_constant,
                    self.omega_matter,
                    co_dict["comoving_box_size"],
                    co_dict["initial_redshift"],
                    0,
                )
                setdefaultattr(self, "time_unit", self.quan(k["utim"], "s"))
                co = Cosmology(
                    hubble_constant=self.hubble_constant,
                    omega_matter=self.omega_matter,
                    omega_lambda=self.omega_lambda,
                )
            else:
                self.cosmological_simulation = 0
        else:
            self.cosmological_simulation = 0

        fh = h5py.File(os.path.join(self.directory, fn0), "r")
        self.domain_left_edge = fh.attrs["lower"]
        self.domain_right_edge = fh.attrs["upper"]

        # all blocks are the same size
        ablock = fh[list(fh.keys())[0]]
        self.current_time = ablock.attrs["time"][0]
        self.parameters["current_cycle"] = ablock.attrs["cycle"][0]
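        # enzo_GridStartIndex and enzo_GridEndIndex bracket the active
        # (non-ghost) zones of a block, so the start index gives the number
        # of ghost zones and the domain dimensions follow from the
        # root-block layout times the active zones per block.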
        gsi = ablock.attrs["enzo_GridStartIndex"]
        gei = ablock.attrs["enzo_GridEndIndex"]
        self.ghost_zones = gsi[0]
        self.root_block_dimensions = root_blocks
        self.active_grid_dimensions = gei - gsi + 1
        self.grid_dimensions = ablock.attrs["enzo_GridDimension"]
        self.domain_dimensions = root_blocks * self.active_grid_dimensions
        fh.close()

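        # With the time unit now known, recover the current redshift from
        # the current time.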
        if self.cosmological_simulation:
            self.current_redshift = co.z_from_t(self.current_time * self.time_unit)

        self._periodicity += (False,) * (3 - self.dimensionality)
        self.gamma = nested_dict_get(self.parameters, ("Field", "gamma"))

    def _set_code_unit_attributes(self):
        if self.cosmological_simulation:
            box_size = self.parameters["Physics"]["cosmology"]["comoving_box_size"]
            k = cosmology_get_units(
                self.hubble_constant,
                self.omega_matter,
                box_size,
                self.parameters["Physics"]["cosmology"]["initial_redshift"],
                self.current_redshift,
            )
            # Now some CGS values
            setdefaultattr(self, "length_unit", self.quan(box_size, "Mpccm/h"))
            setdefaultattr(
                self,
                "mass_unit",
                self.quan(k["urho"], "g/cm**3") * (self.length_unit.in_cgs()) ** 3,
            )
            setdefaultattr(self, "velocity_unit", self.quan(k["uvel"], "cm/s"))
        else:
            p = self.parameters
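            # Non-cosmological units default to CGS (1 cm, 1 s) when not
            # given in the parameter file; the mass unit can be specified
            # directly or derived from a density.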
            for d, u in zip(("length", "time"), ("cm", "s")):
                val = nested_dict_get(p, ("Units", d), default=1)
                setdefaultattr(self, f"{d}_unit", self.quan(val, u))
            mass = nested_dict_get(p, ("Units", "mass"))
            if mass is None:
                density = nested_dict_get(p, ("Units", "density"))
                if density is not None:
                    mass = density * self.length_unit ** 3
                else:
                    mass = 1
            setdefaultattr(self, "mass_unit", self.quan(mass, "g"))
            setdefaultattr(self, "velocity_unit", self.length_unit / self.time_unit)

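        # The magnetic unit follows the Gaussian convention,
        # B = sqrt(4 * pi * rho) * v, expressed here in terms of the mass,
        # time, and length units.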
        magnetic_unit = np.sqrt(
            4 * np.pi * self.mass_unit / (self.time_unit ** 2 * self.length_unit)
        )
        magnetic_unit = np.float64(magnetic_unit.in_cgs())
        setdefaultattr(self, "magnetic_unit", self.quan(magnetic_unit, "gauss"))

    def __str__(self):
        return self.basename[: -len(self._suffix)]

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
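        # A valid Enzo-E dataset is a .block_list file whose first line
        # names a block that parses correctly and points to an existing
        # data file.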
        ddir = os.path.dirname(filename)
        if not filename.endswith(cls._suffix):
            return False
        try:
            with open(filename) as f:
                block, block_file = f.readline().strip().split()
                get_block_info(block)
                if not os.path.exists(os.path.join(ddir, block_file)):
                    return False
        except Exception:
            return False
        return True
