from functools import reduce
from operator import mul
from os import listdir, path
from re import match

import numpy as np
from packaging.version import Version

from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.static_output import Dataset
from yt.data_objects.time_series import DatasetSeries
from yt.frontends.open_pmd.fields import OpenPMDFieldInfo
from yt.frontends.open_pmd.misc import get_component, is_const_component
from yt.funcs import setdefaultattr
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.file_handler import HDF5FileHandler, warn_h5py
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py

ompd_known_versions = [Version(_) for _ in ("1.0.0", "1.0.1", "1.1.0")]
opmd_required_attributes = ["openPMD", "basePath"]

class OpenPMDGrid(AMRGridPatch):
    """Represents a chunk of on-disk data.

    This defines the index and offset for every mesh and particle type.
    It also defines parent and child grids. Since openPMD does not have multiple
    levels of refinement, there are no parents or children for any grid.
    """
31
32    _id_offset = 0
33    __slots__ = ["_level_id"]
34    # Every particle species and mesh might have different hdf5-indices and offsets
35    ftypes = []
36    ptypes = []
37    findex = 0
38    foffset = 0
39    pindex = 0
40    poffset = 0
41
42    def __init__(self, gid, index, level=-1, fi=0, fo=0, pi=0, po=0, ft=None, pt=None):
43        AMRGridPatch.__init__(self, gid, filename=index.index_filename, index=index)
44        if ft is None:
45            ft = []
46        if pt is None:
47            pt = []
48        self.findex = fi
49        self.foffset = fo
50        self.pindex = pi
51        self.poffset = po
52        self.ftypes = ft
53        self.ptypes = pt
54        self.Parent = None
55        self.Children = []
56        self.Level = level
57
58    def __str__(self):
59        return "OpenPMDGrid_%04i (%s)" % (self.id, self.ActiveDimensions)
60

class OpenPMDHierarchy(GridIndex):
    """Defines which fields and particles are created and read from disk.

    Furthermore, it defines the characteristics of the grids.
    """

    grid = OpenPMDGrid

    def __init__(self, ds, dataset_type="openPMD"):
        self.dataset_type = dataset_type
        self.dataset = ds
        self.index_filename = ds.parameter_filename
        self.directory = path.dirname(self.index_filename)
        GridIndex.__init__(self, ds, dataset_type)

    def _get_particle_type_counts(self):
        """Reads the active number of particles for every species.

        Returns
        -------
        dict
            keys are ptypes
            values are integer counts of the ptype
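
        Examples
        --------
        The exact keys depend on the file; a hypothetical dataset with two
        species might produce::

            {"electrons": 400, "ions": 400}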
        """
        result = {}
        f = self.dataset._handle
        bp = self.dataset.base_path
        pp = self.dataset.particles_path

        try:
            for ptype in self.ds.particle_types_raw:
                if str(ptype) == "io":
                    spec = list(f[bp + pp].keys())[0]
                else:
                    spec = ptype
                axis = list(f[bp + pp + "/" + spec + "/position"].keys())[0]
                pos = f[bp + pp + "/" + spec + "/position/" + axis]
                if is_const_component(pos):
                    result[ptype] = pos.attrs["shape"]
                else:
                    result[ptype] = pos.len()
        except KeyError:
            result["io"] = 0

        return result

    def _detect_output_fields(self):
        """Populates ``self.field_list`` with native fields (mesh and particle) on disk.

        Each entry is a tuple of two strings. The first element is the on-disk
        fluid type or particle type. The second element is the name of the field
        in yt. This string is later used for accessing the data.
        By convention, the on-disk fluid type is "openPMD"; the on-disk particle
        type is "io" for a single species of particles, or the on-disk particle
        name when there are multiple species.
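
        Examples
        --------
        Entries depend on the file; a hypothetical dataset with a mesh record
        ``E/x`` and a single particle species might yield entries such as::

            [("openPMD", "E_x"), ("io", "particle_momentum_x")]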
        """
        f = self.dataset._handle
        bp = self.dataset.base_path
        mp = self.dataset.meshes_path
        pp = self.dataset.particles_path

        mesh_fields = []
        try:
            meshes = f[bp + mp]
            for mname in meshes.keys():
                try:
                    mesh = meshes[mname]
                    for axis in mesh.keys():
                        mesh_fields.append(mname.replace("_", "-") + "_" + axis)
                except AttributeError:
                    # This is a h5py.Dataset (i.e. no axes)
                    mesh_fields.append(mname.replace("_", "-"))
        except (KeyError, TypeError, AttributeError):
            pass
        self.field_list = [("openPMD", str(field)) for field in mesh_fields]

        particle_fields = []
        try:
            particles = f[bp + pp]
            for pname in particles.keys():
                species = particles[pname]
                for recname in species.keys():
                    record = species[recname]
                    if is_const_component(record):
                        # Record itself (e.g. particle_mass) is constant
                        particle_fields.append(
                            pname.replace("_", "-") + "_" + recname.replace("_", "-")
                        )
                    elif "particlePatches" not in recname:
                        try:
                            # Create a field for every axis (x,y,z) of every
                            # property (position) of every species (electrons)
                            axes = list(record.keys())
                            if str(recname) == "position":
                                recname = "positionCoarse"
                            for axis in axes:
                                particle_fields.append(
                                    pname.replace("_", "-")
                                    + "_"
                                    + recname.replace("_", "-")
                                    + "_"
                                    + axis
                                )
                        except AttributeError:
                            # Record is a dataset, does not have axes (e.g. weighting)
                            particle_fields.append(
                                pname.replace("_", "-")
                                + "_"
                                + recname.replace("_", "-")
                            )
            if len(list(particles.keys())) > 1:
                # There is more than one particle species,
                # use the specific names as field types
                self.field_list.extend(
                    [
                        (
                            str(field).split("_")[0],
                            ("particle_" + "_".join(str(field).split("_")[1:])),
                        )
                        for field in particle_fields
                    ]
                )
            else:
                # Only one particle species, fall back to "io"
                self.field_list.extend(
                    [
                        ("io", ("particle_" + "_".join(str(field).split("_")[1:])))
                        for field in particle_fields
                    ]
                )
        except (KeyError, TypeError, AttributeError):
            pass

    def _count_grids(self):
        """Sets ``self.num_grids`` to be the total number of grids in the simulation.

        The number of grids is determined by their respective memory footprint.
        """
        f = self.dataset._handle
        bp = self.dataset.base_path
        mp = self.dataset.meshes_path
        pp = self.dataset.particles_path

        self.meshshapes = {}
        self.numparts = {}

        self.num_grids = 0

        try:
            meshes = f[bp + mp]
            for mname in meshes.keys():
                mesh = meshes[mname]
                if isinstance(mesh, h5py.Group):
                    shape = mesh[list(mesh.keys())[0]].shape
                else:
                    shape = mesh.shape
                spacing = tuple(mesh.attrs["gridSpacing"])
                offset = tuple(mesh.attrs["gridGlobalOffset"])
                unit_si = mesh.attrs["gridUnitSI"]
                self.meshshapes[mname] = (shape, spacing, offset, unit_si)
        except (KeyError, TypeError, AttributeError):
            pass
        try:
            particles = f[bp + pp]
            for pname in particles.keys():
                species = particles[pname]
                if "particlePatches" in species.keys():
                    for (patch, size) in enumerate(
                        species["particlePatches/numParticles"]
                    ):
                        self.numparts[f"{pname}#{patch}"] = size
                else:
                    axis = list(species["position"].keys())[0]
                    if is_const_component(species["position/" + axis]):
                        self.numparts[pname] = species["position/" + axis].attrs[
                            "shape"
                        ]
                    else:
                        self.numparts[pname] = species["position/" + axis].len()
        except (KeyError, TypeError, AttributeError):
            pass

        # Limit values per grid by resulting memory footprint
        self.vpg = int(self.dataset.gridsize / 4)  # 4 bytes per value (float32)

        # Meshes of the same size do not need separate chunks
        for shape, *_ in set(self.meshshapes.values()):
            self.num_grids += min(
                shape[0], int(np.ceil(reduce(mul, shape) * self.vpg ** -1))
            )
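        # Worked example (hypothetical numbers): with the default virtual
        # gridsize of 1e9 bytes, self.vpg is 2.5e8 values. A mesh of shape
        # (2048, 2048, 2048) holds ~8.6e9 values, so it is split into
        # min(2048, ceil(8.59e9 / 2.5e8)) == 35 chunks along the x-axis.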

        # Same goes for particle chunks if they are not inside particlePatches
        patches = {}
        no_patches = {}
        for (k, v) in self.numparts.items():
            if "#" in k:
                patches[k] = v
            else:
                no_patches[k] = v
        for size in set(no_patches.values()):
            self.num_grids += int(np.ceil(size * self.vpg ** -1))
        for size in patches.values():
            self.num_grids += int(np.ceil(size * self.vpg ** -1))

    def _parse_index(self):
        """Fills each grid with appropriate properties (extent, dimensions, ...)

        This calculates the properties of every OpenPMDGrid based on the total number of
        grids in the simulation. The domain is divided into ``self.num_grids`` (roughly)
        equally sized chunks along the x-axis. ``grid_levels`` is always equal to 0
        since we only have one level of refinement in openPMD.

        Notes
        -----
        ``self.grid_dimensions`` is rounded to the nearest integer. Grid edges are
        calculated from this dimension. Grids with dimensions [0, 0, 0] are particle
        only. The others do not have any particles affiliated with them.
        """
        f = self.dataset._handle
        bp = self.dataset.base_path
        pp = self.dataset.particles_path

        self.grid_levels.flat[:] = 0
        self.grids = np.empty(self.num_grids, dtype="object")

        grid_index_total = 0

        # Mesh grids
        for mesh in set(self.meshshapes.values()):
            (shape, spacing, offset, unit_si) = mesh
            shape = np.asarray(shape)
            spacing = np.asarray(spacing)
            offset = np.asarray(offset)
            # Total dimension of this grid
            domain_dimension = np.asarray(shape, dtype=np.int32)
            domain_dimension = np.append(
                domain_dimension, np.ones(3 - len(domain_dimension))
            )
            # Number of grids of this shape
            num_grids = min(shape[0], int(np.ceil(reduce(mul, shape) * self.vpg ** -1)))
            gle = offset * unit_si  # self.dataset.domain_left_edge
            gre = (
                domain_dimension[: spacing.size] * unit_si * spacing + gle
            )  # self.dataset.domain_right_edge
            gle = np.append(gle, np.zeros(3 - len(gle)))
            gre = np.append(gre, np.ones(3 - len(gre)))
            grid_dim_offset = np.linspace(
                0, domain_dimension[0], num_grids + 1, dtype=np.int32
            )
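            # Example (hypothetical numbers): a 600-cell x-axis split into
            # num_grids == 4 yields grid_dim_offset == [0, 150, 300, 450, 600],
            # i.e. four chunks of 150 cells each.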
            grid_edge_offset = (
                grid_dim_offset * float(domain_dimension[0]) ** -1 * (gre[0] - gle[0])
                + gle[0]
            )
            mesh_names = []
            for (mname, mdata) in self.meshshapes.items():
                if mesh == mdata:
                    mesh_names.append(str(mname))
            prev = 0
            for grid in np.arange(num_grids):
                self.grid_dimensions[grid_index_total] = domain_dimension
                self.grid_dimensions[grid_index_total][0] = (
                    grid_dim_offset[grid + 1] - grid_dim_offset[grid]
                )
                self.grid_left_edge[grid_index_total] = gle
                self.grid_left_edge[grid_index_total][0] = grid_edge_offset[grid]
                self.grid_right_edge[grid_index_total] = gre
                self.grid_right_edge[grid_index_total][0] = grid_edge_offset[grid + 1]
                self.grid_particle_count[grid_index_total] = 0
                self.grids[grid_index_total] = self.grid(
                    grid_index_total,
                    self,
                    0,
                    fi=prev,
                    fo=self.grid_dimensions[grid_index_total][0],
                    ft=mesh_names,
                )
                prev += self.grid_dimensions[grid_index_total][0]
                grid_index_total += 1

        handled_ptypes = []

        # Particle grids
        for (species, count) in self.numparts.items():
            if "#" in species:
                # This is a particlePatch
                spec = species.split("#")
                patch = f[bp + pp + "/" + spec[0] + "/particlePatches"]
                domain_dimension = np.ones(3, dtype=np.int32)
                for (ind, axis) in enumerate(list(patch["extent"].keys())):
                    domain_dimension[ind] = patch["extent/" + axis][()][int(spec[1])]
                num_grids = int(np.ceil(count * self.vpg ** -1))
                gle = []
                for axis in patch["offset"].keys():
                    gle.append(
                        get_component(patch, "offset/" + axis, int(spec[1]), 1)[0]
                    )
                gle = np.asarray(gle)
                gle = np.append(gle, np.zeros(3 - len(gle)))
                gre = []
                for axis in patch["extent"].keys():
                    gre.append(
                        get_component(patch, "extent/" + axis, int(spec[1]), 1)[0]
                    )
                gre = np.asarray(gre)
                gre = np.append(gre, np.ones(3 - len(gre)))
                np.add(gle, gre, gre)
                npo = patch["numParticlesOffset"][()].item(int(spec[1]))
                particle_count = np.linspace(
                    npo, npo + count, num_grids + 1, dtype=np.int32
                )
                particle_names = [str(spec[0])]
            elif str(species) not in handled_ptypes:
                domain_dimension = self.dataset.domain_dimensions
                num_grids = int(np.ceil(count * self.vpg ** -1))
                gle = self.dataset.domain_left_edge
                gre = self.dataset.domain_right_edge
                particle_count = np.linspace(0, count, num_grids + 1, dtype=np.int32)
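                # Example (hypothetical numbers): count == 400 particles split
                # into num_grids == 4 gives particle_count == [0, 100, 200, 300, 400].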
                particle_names = []
                for (pname, size) in self.numparts.items():
                    if size == count:
                        # Since this is not part of a particlePatch,
                        # we can include multiple same-sized ptypes
                        particle_names.append(str(pname))
                        handled_ptypes.append(str(pname))
            else:
                # A grid with this exact particle count has already been created
                continue
            for grid in np.arange(num_grids):
                self.grid_dimensions[grid_index_total] = domain_dimension
                self.grid_left_edge[grid_index_total] = gle
                self.grid_right_edge[grid_index_total] = gre
                self.grid_particle_count[grid_index_total] = (
                    particle_count[grid + 1] - particle_count[grid]
                ) * len(particle_names)
                self.grids[grid_index_total] = self.grid(
                    grid_index_total,
                    self,
                    0,
                    pi=particle_count[grid],
                    po=particle_count[grid + 1] - particle_count[grid],
                    pt=particle_names,
                )
                grid_index_total += 1

    def _populate_grid_objects(self):
        """This initializes all grids.

        Additionally, it should set up Children and Parent lists on each grid object.
        openPMD is not adaptive and thus there are no Children and Parents for any grid.
        """
        for i in np.arange(self.num_grids):
            self.grids[i]._prepare_grid()
            self.grids[i]._setup_dx()
        self.max_level = 0


class OpenPMDDataset(Dataset):
    """Contains all the required information of a single iteration of the simulation.

    Notes
    -----
    It is assumed that
    - all meshes cover the same region. Their resolution can be different.
    - all particles reside in this same region exclusively.
    - particle and mesh positions are *absolute* with respect to the simulation origin.
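
    Examples
    --------
    A single file-based iteration loads through yt's usual entry point (the
    filename here is hypothetical)::

        import yt

        ds = yt.load("data00000100.h5")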
    """

    _index_class = OpenPMDHierarchy
    _field_info_class = OpenPMDFieldInfo

    def __init__(
        self,
        filename,
        dataset_type="openPMD",
        storage_filename=None,
        units_override=None,
        unit_system="mks",
        **kwargs,
    ):
        self._handle = HDF5FileHandler(filename)
        self.gridsize = kwargs.pop("open_pmd_virtual_gridsize", 10 ** 9)
        self.standard_version = Version(self._handle.attrs["openPMD"].decode())
        self.iteration = kwargs.pop("iteration", None)
        self._set_paths(self._handle, path.dirname(filename), self.iteration)
        Dataset.__init__(
            self,
            filename,
            dataset_type,
            units_override=units_override,
            unit_system=unit_system,
        )
        self.storage_filename = storage_filename
        self.fluid_types += ("openPMD",)
        try:
            particles = tuple(
                str(c)
                for c in self._handle[self.base_path + self.particles_path].keys()
            )
            if len(particles) > 1:
                # Only use on-disk particle names if there is more than one species
                self.particle_types = particles
            mylog.debug("self.particle_types: %s", self.particle_types)
            self.particle_types_raw = self.particle_types
            self.particle_types = tuple(self.particle_types)
        except (KeyError, TypeError, AttributeError):
            pass

    def _set_paths(self, handle, path, iteration):
        """Parses relevant hdf5-paths out of ``handle``.

        Parameters
        ----------
        handle : h5py.File
        path : str
            (absolute) filepath for current hdf5 container
        iteration : int or None
            iteration to load; if None, the first iteration in the file is used
        """
        iterations = []
        if iteration is None:
            iteration = list(handle["/data"].keys())[0]
        encoding = handle.attrs["iterationEncoding"].decode()
        if "groupBased" in encoding:
            iterations = list(handle["/data"].keys())
            mylog.info("Found %s iterations in file", len(iterations))
        elif "fileBased" in encoding:
            itformat = handle.attrs["iterationFormat"].decode().split("/")[-1]
            regex = "^" + itformat.replace("%T", "[0-9]+") + "$"
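            # Example (hypothetical): an iterationFormat of "data%T.h5" yields
            # the regex "^data[0-9]+.h5$", which matches sibling files such as
            # "data100.h5".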
            if path == "":
                mylog.warning(
                    "For file-based iterations, please use absolute file paths!"
                )
            for filename in listdir(path):
                if match(regex, filename):
                    iterations.append(filename)
            mylog.info("Found %s iterations in directory", len(iterations))

        if len(iterations) == 0:
            mylog.warning("No iterations found!")
        if "groupBased" in encoding and len(iterations) > 1:
            mylog.warning("Multiple iterations found; only loading one (%s)", iteration)

        self.base_path = f"/data/{iteration}/"
        try:
            self.meshes_path = self._handle["/"].attrs["meshesPath"].decode()
            handle[self.base_path + self.meshes_path]
        except KeyError:
            if self.standard_version <= Version("1.1.0"):
                mylog.info(
                    "meshesPath not present in file. "
                    "Assuming file contains no meshes and has a domain extent of 1m^3!"
                )
                self.meshes_path = None
            else:
                raise
        try:
            self.particles_path = self._handle["/"].attrs["particlesPath"].decode()
            handle[self.base_path + self.particles_path]
        except KeyError:
            if self.standard_version <= Version("1.1.0"):
                mylog.info(
                    "particlesPath not present in file."
                    " Assuming file contains no particles!"
                )
                self.particles_path = None
            else:
                raise

    def _set_code_unit_attributes(self):
        """Handle conversion between different physical units and the code units.

        Every dataset in openPMD can have a different code <-> physical scaling.
        The individual factor is obtained by multiplying with "unitSI" when
        reading data from disk.
        """
        setdefaultattr(self, "length_unit", self.quan(1.0, "m"))
        setdefaultattr(self, "mass_unit", self.quan(1.0, "kg"))
        setdefaultattr(self, "time_unit", self.quan(1.0, "s"))
        setdefaultattr(self, "velocity_unit", self.quan(1.0, "m/s"))
        setdefaultattr(self, "magnetic_unit", self.quan(1.0, "T"))

    def _parse_parameter_file(self):
        """Read in metadata describing the overall data on-disk."""
        f = self._handle
        bp = self.base_path
        mp = self.meshes_path

        self.unique_identifier = 0
        self.parameters = 0
        self._periodicity = np.zeros(3, dtype="bool")
        self.refine_by = 1
        self.cosmological_simulation = 0

        try:
            shapes = {}
            left_edges = {}
            right_edges = {}
            meshes = f[bp + mp]
            for mname in meshes.keys():
                mesh = meshes[mname]
                if isinstance(mesh, h5py.Group):
                    shape = np.asarray(mesh[list(mesh.keys())[0]].shape)
                else:
                    shape = np.asarray(mesh.shape)
                spacing = np.asarray(mesh.attrs["gridSpacing"])
                offset = np.asarray(mesh.attrs["gridGlobalOffset"])
                unit_si = np.asarray(mesh.attrs["gridUnitSI"])
                le = offset * unit_si
                re = le + shape * unit_si * spacing
                shapes[mname] = shape
                left_edges[mname] = le
                right_edges[mname] = re
            lowest_dim = np.min([len(i) for i in shapes.values()])
            shapes = np.asarray([i[:lowest_dim] for i in shapes.values()])
            left_edges = np.asarray([i[:lowest_dim] for i in left_edges.values()])
            right_edges = np.asarray([i[:lowest_dim] for i in right_edges.values()])
            fs = []
            dle = []
            dre = []
            for i in np.arange(lowest_dim):
                fs.append(np.max(shapes.transpose()[i]))
                dle.append(np.min(left_edges.transpose()[i]))
                dre.append(np.min(right_edges.transpose()[i]))
            self.dimensionality = len(fs)
            self.domain_dimensions = np.append(fs, np.ones(3 - self.dimensionality))
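            # Example (hypothetical): meshes with shapes (32, 64) and (32, 32)
            # give lowest_dim == 2 and fs == [32, 64], so domain_dimensions
            # becomes [32, 64, 1] after padding to three dimensions.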
            self.domain_left_edge = np.append(dle, np.zeros(3 - len(dle)))
            self.domain_right_edge = np.append(dre, np.ones(3 - len(dre)))
        except (KeyError, TypeError, AttributeError):
            if self.standard_version <= Version("1.1.0"):
                self.dimensionality = 3
                self.domain_dimensions = np.ones(3, dtype=np.float64)
                self.domain_left_edge = np.zeros(3, dtype=np.float64)
                self.domain_right_edge = np.ones(3, dtype=np.float64)
            else:
                raise

        self.current_time = f[bp].attrs["time"] * f[bp].attrs["timeUnitSI"]

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        """Checks whether the supplied file can be read by this frontend."""
        warn_h5py(filename)
        try:
            with h5py.File(filename, mode="r") as f:
                attrs = list(f["/"].attrs.keys())
                for i in opmd_required_attributes:
                    if i not in attrs:
                        return False

                if Version(f.attrs["openPMD"].decode()) not in ompd_known_versions:
                    return False

                if f.attrs["iterationEncoding"].decode() == "fileBased":
                    return True

                return False
        except (OSError, ImportError):
            return False


class OpenPMDDatasetSeries(DatasetSeries):
    _pre_outputs = ()
    _dataset_cls = OpenPMDDataset
    parallel = True
    setup_function = None
    mixed_dataset_types = False

    def __init__(self, filename):
        super().__init__([])
        self.handle = h5py.File(filename, mode="r")
        self.filename = filename
        self._pre_outputs = sorted(
            np.asarray(list(self.handle["/data"].keys()), dtype="int64")
        )

    def __iter__(self):
        for it in self._pre_outputs:
            ds = self._load(it, **self.kwargs)
            self._setup_function(ds)
            yield ds

    def __getitem__(self, key):
        if isinstance(key, int):
            o = self._load(key)
            self._setup_function(o)
            return o
        else:
            raise KeyError(f"Unknown iteration {key}")

    def _load(self, it, **kwargs):
        return OpenPMDDataset(self.filename, iteration=it)


class OpenPMDGroupBasedDataset(Dataset):
    _index_class = OpenPMDHierarchy
    _field_info_class = OpenPMDFieldInfo

    def __new__(cls, filename, *args, **kwargs):
        ret = object.__new__(OpenPMDDatasetSeries)
        ret.__init__(filename)
        return ret

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        warn_h5py(filename)
        try:
            with h5py.File(filename, mode="r") as f:
                attrs = list(f["/"].attrs.keys())
                for i in opmd_required_attributes:
                    if i not in attrs:
                        return False

                if Version(f.attrs["openPMD"].decode()) not in ompd_known_versions:
                    return False

                if f.attrs["iterationEncoding"].decode() == "groupBased":
                    return True

                return False
        except (OSError, ImportError):
            return False
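
# Usage sketch (hypothetical file name): constructing the group-based class
# returns an OpenPMDDatasetSeries, which yields one OpenPMDDataset per
# iteration stored in the file.
#
#     series = OpenPMDGroupBasedDataset("simData.h5")
#     for ds in series:
#         print(ds.current_time)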