1import glob
2import os
3
4import numpy as np
5from unyt import dimensions, unyt_array
6from unyt.unit_registry import UnitRegistry
7
8from yt.data_objects.time_series import DatasetSeries, SimulationTimeSeries
9from yt.funcs import only_on_root
10from yt.loaders import load
11from yt.utilities.cosmology import Cosmology
12from yt.utilities.exceptions import (
13    InvalidSimulationTimeSeries,
14    MissingParameter,
15    NoStoppingCondition,
16    YTUnidentifiedDataType,
17)
18from yt.utilities.logger import ytLogger as mylog
19from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_objects
20
21
22class GadgetSimulation(SimulationTimeSeries):
23    r"""
    Initialize a Gadget Simulation object.
25
26    Upon creation, the parameter file is parsed and the time and redshift
27    are calculated and stored in all_outputs.  A time units dictionary is
28    instantiated to allow for time outputs to be requested with physical
    time units.  The get_time_series method can be used to generate a
    DatasetSeries object.
31
32    parameter_filename : str
33        The simulation parameter file.
34    find_outputs : bool
35        If True, the OutputDir directory is searched for datasets.
36        Time and redshift information are gathered by temporarily
37        instantiating each dataset.  This can be used when simulation
38        data was created in a non-standard way, making it difficult
39        to guess the corresponding time and redshift information.
40        Default: False.
41
42    Examples
43    --------
44    >>> import yt
45    >>> gs = yt.load_simulation("my_simulation.par", "Gadget")
46    >>> gs.get_time_series()
47    >>> for ds in gs:
48    ...     print(ds.current_time)
49
50    """
51
52    def __init__(self, parameter_filename, find_outputs=False):
53        self.simulation_type = "particle"
54        self.dimensionality = 3
55        SimulationTimeSeries.__init__(
56            self, parameter_filename, find_outputs=find_outputs
57        )
58
59    def _set_units(self):
60        self.unit_registry = UnitRegistry()
61        self.time_unit = self.quan(1.0, "s")
62        if self.cosmological_simulation:
63            # Instantiate Cosmology object for units and time conversions.
64            self.cosmology = Cosmology(
65                hubble_constant=self.hubble_constant,
66                omega_matter=self.omega_matter,
67                omega_lambda=self.omega_lambda,
68                unit_registry=self.unit_registry,
69            )
70            if "h" in self.unit_registry:
71                self.unit_registry.modify("h", self.hubble_constant)
72            else:
73                self.unit_registry.add(
74                    "h", self.hubble_constant, dimensions.dimensionless
75                )
76            # Comoving lengths
77            for my_unit in ["m", "pc", "AU"]:
78                new_unit = f"{my_unit}cm"
79                # technically not true, but should be ok
80                self.unit_registry.add(
81                    new_unit,
82                    self.unit_registry.lut[my_unit][0],
83                    dimensions.length,
84                    "\\rm{%s}/(1+z)" % my_unit,
85                    prefixable=True,
86                )
87            self.length_unit = self.quan(
88                self.unit_base["UnitLength_in_cm"],
89                "cmcm / h",
90                registry=self.unit_registry,
91            )
92            self.mass_unit = self.quan(
93                self.unit_base["UnitMass_in_g"], "g / h", registry=self.unit_registry
94            )
95            self.box_size = self.box_size * self.length_unit
96            self.domain_left_edge = self.domain_left_edge * self.length_unit
97            self.domain_right_edge = self.domain_right_edge * self.length_unit
98            self.unit_registry.add(
99                "unitary",
100                float(self.box_size.in_base()),
101                self.length_unit.units.dimensions,
102            )
103        else:
104            # Read time from file for non-cosmological sim
105            self.time_unit = self.quan(
106                self.unit_base["UnitLength_in_cm"]
107                / self.unit_base["UnitVelocity_in_cm_per_s"],
108                "s",
109            )
110            self.unit_registry.add("code_time", 1.0, dimensions.time)
111            self.unit_registry.modify("code_time", self.time_unit)
112            # Length
113            self.length_unit = self.quan(self.unit_base["UnitLength_in_cm"], "cm")
114            self.unit_registry.add("code_length", 1.0, dimensions.length)
115            self.unit_registry.modify("code_length", self.length_unit)
116
117    def get_time_series(
118        self,
119        initial_time=None,
120        final_time=None,
121        initial_redshift=None,
122        final_redshift=None,
123        times=None,
124        redshifts=None,
125        tolerance=None,
126        parallel=True,
127        setup_function=None,
128    ):
129
130        """
131        Instantiate a DatasetSeries object for a set of outputs.
132
133        If no additional keywords given, a DatasetSeries object will be
134        created with all potential datasets created by the simulation.
135
136        Outputs can be gather by specifying a time or redshift range
137        (or combination of time and redshift), with a specific list of
138        times or redshifts), or by simply searching all subdirectories
139        within the simulation directory.
140
141        initial_time : tuple of type (float, str)
142            The earliest time for outputs to be included.  This should be
143            given as the value and the string representation of the units.
144            For example, (5.0, "Gyr").  If None, the initial time of the
145            simulation is used.  This can be used in combination with
146            either final_time or final_redshift.
147            Default: None.
148        final_time : tuple of type (float, str)
149            The latest time for outputs to be included.  This should be
150            given as the value and the string representation of the units.
151            For example, (13.7, "Gyr"). If None, the final time of the
152            simulation is used.  This can be used in combination with either
153            initial_time or initial_redshift.
154            Default: None.
155        times : tuple of type (float array, str)
156            A list of times for which outputs will be found and the units
157            of those values.  For example, ([0, 1, 2, 3], "s").
158            Default: None.
159        initial_redshift : float
160            The earliest redshift for outputs to be included.  If None,
161            the initial redshift of the simulation is used.  This can be
162            used in combination with either final_time or
163            final_redshift.
164            Default: None.
165        final_redshift : float
166            The latest redshift for outputs to be included.  If None,
167            the final redshift of the simulation is used.  This can be
168            used in combination with either initial_time or
169            initial_redshift.
170            Default: None.
171        redshifts : array_like
172            A list of redshifts for which outputs will be found.
173            Default: None.
174        tolerance : float
175            Used in combination with "times" or "redshifts" keywords,
176            this is the tolerance within which outputs are accepted
177            given the requested times or redshifts.  If None, the
178            nearest output is always taken.
179            Default: None.
180        parallel : bool/int
181            If True, the generated DatasetSeries will divide the work
182            such that a single processor works on each dataset.  If an
183            integer is supplied, the work will be divided into that
184            number of jobs.
185            Default: True.
186        setup_function : callable, accepts a ds
187            This function will be called whenever a dataset is loaded.
188
189        Examples
190        --------
191
192        >>> import yt
193        >>> gs = yt.load_simulation("my_simulation.par", "Gadget")
194
195        >>> gs.get_time_series(initial_redshift=10, final_time=(13.7, "Gyr"))
196
197        >>> gs.get_time_series(redshifts=[3, 2, 1, 0])
198
199        >>> # after calling get_time_series
200        >>> for ds in gs.piter():
201        ...     p = ProjectionPlot(ds, "x", ("gas", "density"))
202        ...     p.save()
203
204        >>> # An example using the setup_function keyword
205        >>> def print_time(ds):
206        ...     print(ds.current_time)
207        >>> gs.get_time_series(setup_function=print_time)
208        >>> for ds in gs:
209        ...     SlicePlot(ds, "x", "Density").save()
210
211        """
212
213        if (
214            initial_redshift is not None or final_redshift is not None
215        ) and not self.cosmological_simulation:
216            raise InvalidSimulationTimeSeries(
217                "An initial or final redshift has been given for a "
218                + "noncosmological simulation."
219            )
220
221        my_all_outputs = self.all_outputs
222        if not my_all_outputs:
223            DatasetSeries.__init__(
224                self, outputs=[], parallel=parallel, unit_base=self.unit_base
225            )
226            mylog.info("0 outputs loaded into time series.")
227            return
228
229        # Apply selection criteria to the set.
230        if times is not None:
231            my_outputs = self._get_outputs_by_key(
232                "time", times, tolerance=tolerance, outputs=my_all_outputs
233            )
234
235        elif redshifts is not None:
236            my_outputs = self._get_outputs_by_key(
237                "redshift", redshifts, tolerance=tolerance, outputs=my_all_outputs
238            )
239
240        else:
241            if initial_time is not None:
242                if isinstance(initial_time, float):
243                    initial_time = self.quan(initial_time, "code_time")
244                elif isinstance(initial_time, tuple) and len(initial_time) == 2:
245                    initial_time = self.quan(*initial_time)
246                elif not isinstance(initial_time, unyt_array):
247                    raise RuntimeError(
248                        "Error: initial_time must be given as a float or "
249                        + "tuple of (value, units)."
250                    )
251            elif initial_redshift is not None:
252                my_initial_time = self.cosmology.t_from_z(initial_redshift)
253            else:
254                my_initial_time = self.initial_time
255
256            if final_time is not None:
257                if isinstance(final_time, float):
258                    final_time = self.quan(final_time, "code_time")
259                elif isinstance(final_time, tuple) and len(final_time) == 2:
260                    final_time = self.quan(*final_time)
261                elif not isinstance(final_time, unyt_array):
262                    raise RuntimeError(
263                        "Error: final_time must be given as a float or "
264                        + "tuple of (value, units)."
265                    )
266                my_final_time = final_time.in_units("s")
267            elif final_redshift is not None:
268                my_final_time = self.cosmology.t_from_z(final_redshift)
269            else:
270                my_final_time = self.final_time
271
272            my_initial_time.convert_to_units("s")
273            my_final_time.convert_to_units("s")
274            my_times = np.array([a["time"] for a in my_all_outputs])
275            my_indices = np.digitize([my_initial_time, my_final_time], my_times)
276            if my_initial_time == my_times[my_indices[0] - 1]:
277                my_indices[0] -= 1
278            my_outputs = my_all_outputs[my_indices[0] : my_indices[1]]
279
280        init_outputs = []
281        for output in my_outputs:
282            if os.path.exists(output["filename"]):
283                init_outputs.append(output["filename"])
284        if len(init_outputs) == 0 and len(my_outputs) > 0:
285            mylog.warning(
286                "Could not find any datasets.  "
287                "Check the value of OutputDir in your parameter file."
288            )
289
290        DatasetSeries.__init__(
291            self,
292            outputs=init_outputs,
293            parallel=parallel,
294            setup_function=setup_function,
295            unit_base=self.unit_base,
296        )
297        mylog.info("%d outputs loaded into time series.", len(init_outputs))
298
299    def _parse_parameter_file(self):
300        """
301        Parses the parameter file and establishes the various
302        dictionaries.
303        """
304
305        self.unit_base = {}
306
307        # Let's read the file
308        lines = open(self.parameter_filename).readlines()
309        comments = ["%", ";"]
310        for line in (l.strip() for l in lines):
311            for comment in comments:
312                if comment in line:
313                    line = line[0 : line.find(comment)]
314            if len(line) < 2:
315                continue
316            param, vals = (i.strip() for i in line.split(None, 1))
317            # First we try to decipher what type of value it is.
318            vals = vals.split()
319            # Special case approaching.
320            if "(do" in vals:
321                vals = vals[:1]
322            if len(vals) == 0:
323                pcast = str  # Assume NULL output
324            else:
325                v = vals[0]
326                # Figure out if it's castable to floating point:
327                try:
328                    float(v)
329                except ValueError:
330                    pcast = str
331                else:
332                    if any("." in v or "e" in v for v in vals):
333                        pcast = float
334                    elif v == "inf":
335                        pcast = str
336                    else:
337                        pcast = int
338            # Now we figure out what to do with it.
339            if param.startswith("Unit"):
340                self.unit_base[param] = float(vals[0])
341            if len(vals) == 0:
342                vals = ""
343            elif len(vals) == 1:
344                vals = pcast(vals[0])
345            else:
346                vals = np.array([pcast(i) for i in vals])
347
348            self.parameters[param] = vals
349
350        # Domain dimensions for Gadget datasets are always 2x2x2 for octree
351        self.domain_dimensions = np.array([2, 2, 2])
352
353        if self.parameters["ComovingIntegrationOn"]:
354            cosmo_attr = {
355                "box_size": "BoxSize",
356                "omega_lambda": "OmegaLambda",
357                "omega_matter": "Omega0",
358                "hubble_constant": "HubbleParam",
359            }
360            self.initial_redshift = 1.0 / self.parameters["TimeBegin"] - 1.0
361            self.final_redshift = 1.0 / self.parameters["TimeMax"] - 1.0
362            self.cosmological_simulation = 1
363            for a, v in cosmo_attr.items():
364                if v not in self.parameters:
365                    raise MissingParameter(self.parameter_filename, v)
366                setattr(self, a, self.parameters[v])
367            self.domain_left_edge = np.array([0.0, 0.0, 0.0])
368            self.domain_right_edge = (
369                np.array([1.0, 1.0, 1.0]) * self.parameters["BoxSize"]
370            )
371        else:
372            self.cosmological_simulation = 0
373            self.omega_lambda = self.omega_matter = self.hubble_constant = 0.0
374
375    def _find_data_dir(self):
376        """
377        Find proper location for datasets.  First look where parameter file
378        points, but if this doesn't exist then default to the current
379        directory.
380        """
381        if self.parameters["OutputDir"].startswith("/"):
382            data_dir = self.parameters["OutputDir"]
383        else:
384            data_dir = os.path.join(self.directory, self.parameters["OutputDir"])
385        if not os.path.exists(data_dir):
386            mylog.info(
387                "OutputDir not found at %s, instead using %s.", data_dir, self.directory
388            )
389            data_dir = self.directory
390        self.data_dir = data_dir
391
392    def _snapshot_format(self, index=None):
393        """
394        The snapshot filename for a given index.  Modify this for different
395        naming conventions.
396        """
397
398        if self.parameters["NumFilesPerSnapshot"] > 1:
399            suffix = ".0"
400        else:
401            suffix = ""
402        if self.parameters["SnapFormat"] == 3:
403            suffix += ".hdf5"
404        if index is None:
405            count = "*"
406        else:
407            count = "%03d" % index
408        filename = f"{self.parameters['SnapshotFileBase']}_{count}{suffix}"
409        return os.path.join(self.data_dir, filename)
410
411    def _get_all_outputs(self, find_outputs=False):
412        """
413        Get all potential datasets and combine into a time-sorted list.
414        """
415
416        # Find the data directory where the outputs are
417        self._find_data_dir()
418
419        # Create the set of outputs from which further selection will be done.
420        if find_outputs:
421            self._find_outputs()
422        else:
423            if self.parameters["OutputListOn"]:
424                a_values = [
425                    float(a)
426                    for a in open(
427                        os.path.join(
428                            self.data_dir, self.parameters["OutputListFilename"]
429                        ),
430                    ).readlines()
431                ]
432            else:
433                a_values = [float(self.parameters["TimeOfFirstSnapshot"])]
434                time_max = float(self.parameters["TimeMax"])
435                while a_values[-1] < time_max:
436                    if self.cosmological_simulation:
437                        a_values.append(
438                            a_values[-1] * self.parameters["TimeBetSnapshot"]
439                        )
440                    else:
441                        a_values.append(
442                            a_values[-1] + self.parameters["TimeBetSnapshot"]
443                        )
444                if a_values[-1] > time_max:
445                    a_values[-1] = time_max
446
447            if self.cosmological_simulation:
448                self.all_outputs = [
449                    {"filename": self._snapshot_format(i), "redshift": (1.0 / a - 1)}
450                    for i, a in enumerate(a_values)
451                ]
452
453                # Calculate times for redshift outputs.
454                for output in self.all_outputs:
455                    output["time"] = self.cosmology.t_from_z(output["redshift"])
456            else:
457                self.all_outputs = [
458                    {
459                        "filename": self._snapshot_format(i),
460                        "time": self.quan(a, "code_time"),
461                    }
462                    for i, a in enumerate(a_values)
463                ]
464
465            self.all_outputs.sort(key=lambda obj: obj["time"].to_ndarray())
466
467    def _calculate_simulation_bounds(self):
468        """
469        Figure out the starting and stopping time and redshift for the simulation.
470        """
471
472        # Convert initial/final redshifts to times.
473        if self.cosmological_simulation:
474            self.initial_time = self.cosmology.t_from_z(self.initial_redshift)
475            self.initial_time.units.registry = self.unit_registry
476            self.final_time = self.cosmology.t_from_z(self.final_redshift)
477            self.final_time.units.registry = self.unit_registry
478
479        # If not a cosmology simulation, figure out the stopping criteria.
480        else:
481            if "TimeBegin" in self.parameters:
482                self.initial_time = self.quan(self.parameters["TimeBegin"], "code_time")
483            else:
484                self.initial_time = self.quan(0.0, "code_time")
485
486            if "TimeMax" in self.parameters:
487                self.final_time = self.quan(self.parameters["TimeMax"], "code_time")
488            else:
489                self.final_time = None
490            if "TimeMax" not in self.parameters:
491                raise NoStoppingCondition(self.parameter_filename)
492
493    def _find_outputs(self):
494        """
495        Search for directories matching the data dump keywords.
496        If found, get dataset times py opening the ds.
497        """
498        potential_outputs = glob.glob(self._snapshot_format())
499        self.all_outputs = self._check_for_outputs(potential_outputs)
500        self.all_outputs.sort(key=lambda obj: obj["time"])
501        only_on_root(mylog.info, "Located %d total outputs.", len(self.all_outputs))
502
503        # manually set final time and redshift with last output
504        if self.all_outputs:
505            self.final_time = self.all_outputs[-1]["time"]
506            if self.cosmological_simulation:
507                self.final_redshift = self.all_outputs[-1]["redshift"]
508
509    def _check_for_outputs(self, potential_outputs):
510        r"""
511        Check a list of files to see if they are valid datasets.
512        """
513
514        only_on_root(
515            mylog.info, "Checking %d potential outputs.", len(potential_outputs)
516        )
517
518        my_outputs = {}
519        for my_storage, output in parallel_objects(
520            potential_outputs, storage=my_outputs
521        ):
522            try:
523                ds = load(output)
524            except (FileNotFoundError, YTUnidentifiedDataType):
525                mylog.error("Failed to load %s", output)
526                continue
527            my_storage.result = {
528                "filename": output,
529                "time": ds.current_time.in_units("s"),
530            }
531            if ds.cosmological_simulation:
532                my_storage.result["redshift"] = ds.current_redshift
533
534        my_outputs = [
535            my_output for my_output in my_outputs.values() if my_output is not None
536        ]
537        return my_outputs
538
539    def _write_cosmology_outputs(self, filename, outputs, start_index, decimals=3):
540        r"""
541        Write cosmology output parameters for a cosmology splice.
542        """
543
544        mylog.info("Writing redshift output list to %s.", filename)
545        f = open(filename, "w")
546        for output in outputs:
547            f.write(f"{1.0 / (1.0 + output['redshift']):f}\n")
548        f.close()
549