import functools
import glob
import inspect
import os
import weakref
from functools import wraps

import numpy as np
from more_itertools import always_iterable

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.config import ytcfg
from yt.data_objects.analyzer_objects import AnalysisTask, create_quantity_proxy
from yt.data_objects.particle_trajectories import ParticleTrajectories
from yt.funcs import is_sequence, mylog
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.exceptions import YTException
from yt.utilities.object_registries import (
    analysis_task_registry,
    data_object_registry,
    derived_quantity_registry,
    simulation_time_series_registry,
)
from yt.utilities.parallel_tools.parallel_analysis_interface import (
    communication_system,
    parallel_objects,
    parallel_root_only,
)


class AnalysisTaskProxy:
    def __init__(self, time_series):
        self.time_series = time_series

    def __getitem__(self, key):
        task_cls = analysis_task_registry[key]

        @wraps(task_cls.__init__)
        def func(*args, **kwargs):
            task = task_cls(*args, **kwargs)
            return self.time_series.eval(task)

        return func

    def keys(self):
        return analysis_task_registry.keys()

    def __contains__(self, key):
        return key in analysis_task_registry


def get_ds_prop(propname):
    def _eval(params, ds):
        return getattr(ds, propname)

    cls = type(propname, (AnalysisTask,), dict(eval=_eval, _params=tuple()))
    return cls


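# Dataset attributes that can be evaluated across a whole time series through
# DatasetSeries.params (see TimeSeriesParametersContainer below).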
attrs = (
    "refine_by",
    "dimensionality",
    "current_time",
    "domain_dimensions",
    "domain_left_edge",
    "domain_right_edge",
    "unique_identifier",
    "current_redshift",
    "cosmological_simulation",
    "omega_matter",
    "omega_lambda",
    "omega_radiation",
    "hubble_constant",
)


class TimeSeriesParametersContainer:
    def __init__(self, data_object):
        self.data_object = data_object

    def __getattr__(self, attr):
        if attr in attrs:
            return self.data_object.eval(get_ds_prop(attr)())
        raise AttributeError(attr)


class DatasetSeries:
    r"""The DatasetSeries object is a container of multiple datasets,
    allowing easy iteration and computation on them.

    DatasetSeries objects are designed to provide easy ways to access,
    analyze, parallelize and visualize multiple datasets sequentially.  This is
    primarily expressed through iteration, but can also be constructed via
    analysis tasks (see :ref:`time-series-analysis`).

    Note that contained datasets are lazily loaded and weakly referenced. This
    means that, in order to perform follow-up operations on data, it is best to
    keep explicit handles to these datasets during iteration.

    Parameters
    ----------
    outputs : list of filenames, or pattern
        A list of filenames, for instance ["DD0001/DD0001", "DD0002/DD0002"],
        or a glob pattern (i.e. containing wildcards '[]?!*') such as "DD*/DD*.index".
        In the latter case, results are sorted automatically.
        Filenames and patterns can be of type str, os.PathLike or bytes.
    parallel : True, False or int
        This parameter governs the behavior when .piter() is called on the
        resultant DatasetSeries object.  If this is set to False, the time
        series will not iterate in parallel when .piter() is called.  If
        this is set to True, one processor will be allocated for
        each iteration of the loop. If this is set to an integer, the loop
        will be parallelized over this many workgroups. If the integer
        value is less than the total number of available processors,
        more than one processor will be allocated to a given loop iteration,
        causing the functionality within the loop to be run in parallel.
    setup_function : callable, accepts a ds
        This function will be called whenever a dataset is loaded.
    mixed_dataset_types : True or False, default False
        Set to True if the DatasetSeries will load different dataset types, set
        to False if loading datasets of a single type, as this results in a
        considerable speed-up by not having to figure out the dataset type for
        each output.

    Examples
    --------

    >>> ts = DatasetSeries(
    ...     "GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_0[0-6][0-9]0"
    ... )
    >>> for ds in ts:
    ...     SlicePlot(ds, "x", ("gas", "density")).save()
    ...
    >>> def print_time(ds):
    ...     print(ds.current_time)
    ...
    >>> ts = DatasetSeries(
    ...     "GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_0[0-6][0-9]0",
    ...     setup_function=print_time,
    ... )
    ...
    >>> for ds in ts:
    ...     SlicePlot(ds, "x", ("gas", "density")).save()
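
    A series can also be indexed and sliced; a minimal sketch, reusing the
    series created above (an integer index loads a single dataset, while a
    slice returns a new DatasetSeries):

    >>> ds = ts[0]
    >>> every_other = ts[::2]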

    """

    def __init_subclass__(cls, *args, **kwargs):
        super().__init_subclass__(*args, **kwargs)
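        # Everything preceding the "Simulation" suffix of the subclass name is
        # used as the code name, e.g. "EnzoSimulation" registers under "Enzo".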
        code_name = cls.__name__[: cls.__name__.find("Simulation")]
        if code_name:
            simulation_time_series_registry[code_name] = cls
            mylog.debug("Registering simulation: %s as %s", code_name, cls)

    def __new__(cls, outputs, *args, **kwargs):
        try:
            outputs = cls._get_filenames_from_glob_pattern(outputs)
        except TypeError:
            pass
        ret = super().__new__(cls)
        ret._pre_outputs = outputs[:]
        return ret

    def __init__(
        self,
        outputs,
        parallel=True,
        setup_function=None,
        mixed_dataset_types=False,
        **kwargs,
    ):
        self._mixed_dataset_types = mixed_dataset_types
        # This is needed to properly set _pre_outputs for Simulation subclasses.
        if is_sequence(outputs) and not isinstance(outputs, str):
            self._pre_outputs = outputs[:]
        self.tasks = AnalysisTaskProxy(self)
        self.params = TimeSeriesParametersContainer(self)
        if setup_function is None:

            def _null(x):
                return None

            setup_function = _null
        self._setup_function = setup_function
        for type_name in data_object_registry:
            setattr(
                self, type_name, functools.partial(DatasetSeriesObject, self, type_name)
            )
        self.parallel = parallel
        self.kwargs = kwargs

    @staticmethod
    def _get_filenames_from_glob_pattern(outputs):
        """
        Helper function for DatasetSeries.__new__ that handles the special
        case where "outputs" is really a glob pattern string.
        """
        pattern = outputs
        epattern = os.path.expanduser(pattern)
        data_dir = ytcfg.get("yt", "test_data_dir")
        # if no match is found from the current working dir,
        # we try to match the pattern from the test data dir
        file_list = glob.glob(epattern) or glob.glob(os.path.join(data_dir, epattern))
        if not file_list:
            raise FileNotFoundError(f"No match found for pattern: {pattern}")
        return sorted(file_list)

    def __getitem__(self, key):
        if isinstance(key, slice):
            if isinstance(key.start, float):
                return self.get_range(key.start, key.stop)
            # This will return a sliced up object!
            return DatasetSeries(
                self._pre_outputs[key], parallel=self.parallel, **self.kwargs
            )
        o = self._pre_outputs[key]
        if isinstance(o, (str, os.PathLike)):
            o = self._load(o, **self.kwargs)
            self._setup_function(o)
        return o

    def __len__(self):
        return len(self._pre_outputs)

    @property
    def outputs(self):
        return self._pre_outputs

    def piter(self, storage=None, dynamic=False):
        r"""Iterate over time series components in parallel.

        This allows you to iterate over a time series while dispatching
        individual components of that time series to different processors or
        processor groups.  If the parallelism strategy was set to be
        multi-processor (by "parallel = N" where N is an integer when the
        DatasetSeries was created) this will issue each dataset to an
        N-processor group.  For instance, this would allow you to start a 1024
        processor job, loading up 100 datasets in a time series and creating 8
        processor groups of 128 processors each, each of which would be
        assigned a different dataset.  This could be accomplished as shown in
        the examples below.  The *storage* option is as seen in
        :func:`~yt.utilities.parallel_tools.parallel_analysis_interface.parallel_objects`
        which is a mechanism for storing results of analysis on an individual
        dataset and then combining the results at the end, so that the entire
        set of processors has access to those results.

        Note that supplying a *storage* changes the iteration mechanism; see
        below.

        Parameters
        ----------
        storage : dict
            This is a dictionary, which will be filled with results during the
            course of the iteration.  The keys will be the dataset
            indices and the values will be whatever is assigned to the *result*
            attribute on the storage during iteration.
        dynamic : boolean
            This governs whether or not dynamic load balancing will be
            enabled.  This requires one dedicated processor; if this
            is enabled with a set of 128 processors available, only
            127 will be available to iterate over objects as one will
            be load balancing the rest.


        Examples
        --------
        Here is an example of iteration when the results do not need to be
        stored.  One processor will be assigned to each dataset.

        >>> ts = DatasetSeries("DD*/DD*.index")
        >>> for ds in ts.piter():
        ...     SlicePlot(ds, "x", ("gas", "density")).save()
        ...

        This demonstrates how one might store results:

        >>> def print_time(ds):
        ...     print(ds.current_time)
        ...
        >>> ts = DatasetSeries("DD*/DD*.index", setup_function=print_time)
        ...
        >>> my_storage = {}
        >>> for sto, ds in ts.piter(storage=my_storage):
        ...     v, c = ds.find_max(("gas", "density"))
        ...     sto.result = (v, c)
        ...
        >>> for i, (v, c) in sorted(my_storage.items()):
        ...     print("% 4i  %0.3e" % (i, v))
        ...

        This shows how to dispatch 4 processors to each dataset:

        >>> ts = DatasetSeries("DD*/DD*.index", parallel=4)
        >>> for ds in ts.piter():
        ...     ProjectionPlot(ds, "x", ("gas", "density")).save()
        ...
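
        Dynamic load balancing can be requested with the *dynamic* keyword.
        A minimal sketch, reusing the series above (one processor is set
        aside to distribute work to the others):

        >>> for ds in ts.piter(dynamic=True):
        ...     SlicePlot(ds, "x", ("gas", "density")).save()
        ...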

        """
        if not self.parallel:
            njobs = 1
        elif not dynamic:
            if self.parallel is True:
                njobs = -1
            else:
                njobs = self.parallel
        else:
            my_communicator = communication_system.communicators[-1]
            nsize = my_communicator.size
            if nsize == 1:
                self.parallel = False
                dynamic = False
                njobs = 1
            else:
                njobs = nsize - 1

        for output in parallel_objects(
            self._pre_outputs, njobs=njobs, storage=storage, dynamic=dynamic
        ):
            if storage is not None:
                sto, output = output

            if isinstance(output, str):
                ds = self._load(output, **self.kwargs)
                self._setup_function(ds)
            else:
                ds = output

            if storage is not None:
                next_ret = (sto, ds)
            else:
                next_ret = ds

            yield next_ret

    def eval(self, tasks, obj=None):
        return_values = {}
        for store, ds in self.piter(return_values):
            store.result = []
            for task in always_iterable(tasks):
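                # A task declares what it operates on through the name of the
                # first argument (after self) of its eval() method: "ds" means
                # it wants the dataset itself, "data_object" means it wants a
                # data object (all_data by default) built from that dataset.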
                try:
                    style = inspect.getfullargspec(task.eval)[0][1]
                    if style == "ds":
                        arg = ds
                    elif style == "data_object":
                        if obj is None:
                            obj = DatasetSeriesObject(self, "all_data")
                        arg = obj.get(ds)
                    rv = task.eval(arg)
                # We catch and skip YT-originating exceptions here (e.g. a
                # sphere that is too small), so that one failing dataset does
                # not abort the whole evaluation.
                except YTException:
                    continue
                store.result.append(rv)
        return [v for k, v in sorted(return_values.items())]

    @classmethod
    def from_filenames(cls, filenames, parallel=True, setup_function=None, **kwargs):
        r"""Create a time series from either a filename pattern or a list of
        filenames.

        This method provides an easy way to create a
        :class:`~yt.data_objects.time_series.DatasetSeries`, given a set of
        filenames or a pattern that matches them.  Additionally, it can set the
        parallelism strategy.

        Parameters
        ----------
        filenames : list or pattern
            This can either be a list of filenames (such as ["DD0001/DD0001",
            "DD0002/DD0002"]) or a pattern to match, such as
            "DD*/DD*.index".  If it's the former, they will be loaded in
            order.  The latter will be identified with the glob module and then
            sorted.
        parallel : True, False or int
            This parameter governs the behavior when .piter() is called on the
            resultant DatasetSeries object.  If this is set to False, the time
            series will not iterate in parallel when .piter() is called.  If
            this is set to either True or an integer, it will be iterated with
            1 or that integer number of processors assigned to each parameter
            file provided to the loop.
        setup_function : callable, accepts a ds
            This function will be called whenever a dataset is loaded.

        Examples
        --------

        >>> def print_time(ds):
        ...     print(ds.current_time)
        ...
        >>> ts = DatasetSeries.from_filenames(
        ...     "GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_0[0-6][0-9]0",
        ...     setup_function=print_time,
        ... )
        ...
        >>> for ds in ts:
        ...     SlicePlot(ds, "x", ("gas", "density")).save()

        """
        issue_deprecation_warning(
            "DatasetSeries.from_filenames() is deprecated and will be removed "
            "in a future version of yt. Use DatasetSeries() directly.",
            since="4.0.0",
            removal="4.1.0",
        )
        obj = cls(filenames, parallel=parallel, setup_function=setup_function, **kwargs)
        return obj

    @classmethod
    def from_output_log(cls, output_log, line_prefix="DATASET WRITTEN", parallel=True):
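        r"""Create a time series from a simulation output log.

        Each line of *output_log* that starts with *line_prefix* is assumed
        to carry the dataset filename as the first whitespace-separated token
        after the prefix, e.g. (hypothetical log line)::

            DATASET WRITTEN DD0001/DD0001 1.2345e-02

        Lines that do not start with *line_prefix* are ignored.
        """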
        filenames = []
        with open(output_log) as fh:
            for line in fh:
                if not line.startswith(line_prefix):
                    continue
                cut_line = line[len(line_prefix) :].strip()
                fn = cut_line.split()[0]
                filenames.append(fn)
        obj = cls(filenames, parallel=parallel)
        return obj

    _dataset_cls = None

    def _load(self, output_fn, **kwargs):
        from yt.loaders import load

        if self._dataset_cls is not None:
            return self._dataset_cls(output_fn, **kwargs)
        elif self._mixed_dataset_types:
            return load(output_fn, **kwargs)
        ds = load(output_fn, **kwargs)
        self._dataset_cls = ds.__class__
        return ds

    def particle_trajectories(
        self, indices, fields=None, suppress_logging=False, ptype=None
    ):
        r"""Create a collection of particle trajectories in time over a series of
        datasets.

        Parameters
        ----------
        indices : array_like
            An integer array of particle indices whose trajectories we
            want to track. If they are not sorted, they will be sorted.
        fields : list of strings, optional
            A set of fields that is retrieved when the trajectory
            collection is instantiated. Default: None (will default
            to the fields 'particle_position_x', 'particle_position_y',
            'particle_position_z')
        suppress_logging : boolean
            Suppress yt's logging when iterating over the simulation time
            series. Default: False
        ptype : str, optional
            Only use this particle type. Default: None, which uses all
            particle types.

        Examples
        --------
        >>> my_fns = glob.glob("orbit_hdf5_chk_00[0-9][0-9]")
        >>> my_fns.sort()
        >>> fields = [
        ...     ("all", "particle_position_x"),
        ...     ("all", "particle_position_y"),
        ...     ("all", "particle_position_z"),
        ...     ("all", "particle_velocity_x"),
        ...     ("all", "particle_velocity_y"),
        ...     ("all", "particle_velocity_z"),
        ... ]
        >>> ds = load(my_fns[0])
        >>> init_sphere = ds.sphere(ds.domain_center, (0.5, "unitary"))
        >>> indices = init_sphere[("all", "particle_index")].astype("int")
        >>> ts = DatasetSeries(my_fns)
        >>> trajs = ts.particle_trajectories(indices, fields=fields)
        >>> for t in trajs:
        ...     print(
        ...         t[("all", "particle_velocity_x")].max(),
        ...         t[("all", "particle_velocity_x")].min(),
        ...     )
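
        The trajectories can be restricted to a single particle type with the
        *ptype* keyword; a minimal sketch ("io" is a hypothetical particle
        type name, substitute one defined by your dataset):

        >>> trajs = ts.particle_trajectories(indices, fields=fields, ptype="io")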

        Notes
        -----
        This function will fail if there are duplicate particle ids or if some
        of the particles disappear.
        """
        return ParticleTrajectories(
            self, indices, fields=fields, suppress_logging=suppress_logging, ptype=ptype
        )


class TimeSeriesQuantitiesContainer:
    def __init__(self, data_object, quantities):
        self.data_object = data_object
        self.quantities = quantities

    def __getitem__(self, key):
        if key not in self.quantities:
            raise KeyError(key)
        q = self.quantities[key]

        def run_quantity_wrapper(quantity, quantity_name):
            @wraps(derived_quantity_registry[quantity_name][1])
            def run_quantity(*args, **kwargs):
                to_run = quantity(*args, **kwargs)
                return self.data_object.eval(to_run)

            return run_quantity

        return run_quantity_wrapper(q, key)


class DatasetSeriesObject:
    def __init__(self, time_series, data_object_name, *args, **kwargs):
        self.time_series = weakref.proxy(time_series)
        self.data_object_name = data_object_name
        self._args = args
        self._kwargs = kwargs
        qs = {
            qn: create_quantity_proxy(qv)
            for qn, qv in derived_quantity_registry.items()
        }
        self.quantities = TimeSeriesQuantitiesContainer(self, qs)

    def eval(self, tasks):
        return self.time_series.eval(tasks, self)

    def get(self, ds):
        # We get the type name, which corresponds to an attribute of the
        # index
        cls = getattr(ds, self.data_object_name)
        return cls(*self._args, **self._kwargs)


class SimulationTimeSeries(DatasetSeries):
    def __init__(self, parameter_filename, find_outputs=False):
        """
        Base class for generating simulation time series types.
        Principally consists of a *parameter_filename*.
        """

        if not os.path.exists(parameter_filename):
            raise FileNotFoundError(parameter_filename)
        self.parameter_filename = parameter_filename
        self.basename = os.path.basename(parameter_filename)
        self.directory = os.path.dirname(parameter_filename)
        self.parameters = {}
        self.key_parameters = []

        # Set some parameter defaults.
        self._set_parameter_defaults()
        # Read the simulation dataset.
        self._parse_parameter_file()
        # Set units.
        self._set_units()
        # Figure out the starting and stopping times and redshift.
        self._calculate_simulation_bounds()
        # Get all possible datasets.
        self._get_all_outputs(find_outputs=find_outputs)

        self.print_key_parameters()

    def _set_parameter_defaults(self):
        pass

    def _parse_parameter_file(self):
        pass

    def _set_units(self):
        pass

    def _calculate_simulation_bounds(self):
        pass

    def _get_all_outputs(self, **kwargs):
        pass

    def __repr__(self):
        return self.parameter_filename

    _arr = None

    @property
    def arr(self):
        if self._arr is not None:
            return self._arr
        self._arr = functools.partial(YTArray, registry=self.unit_registry)
        return self._arr

    _quan = None

    @property
    def quan(self):
        if self._quan is not None:
            return self._quan
        self._quan = functools.partial(YTQuantity, registry=self.unit_registry)
        return self._quan

    @parallel_root_only
    def print_key_parameters(self):
        """
        Print out some key parameters for the simulation.
        """
        if self.simulation_type == "grid":
            for a in ["domain_dimensions", "domain_left_edge", "domain_right_edge"]:
                self._print_attr(a)
        for a in ["initial_time", "final_time", "cosmological_simulation"]:
            self._print_attr(a)
        if getattr(self, "cosmological_simulation", False):
            for a in [
                "box_size",
                "omega_matter",
                "omega_lambda",
                "omega_radiation",
                "hubble_constant",
                "initial_redshift",
                "final_redshift",
            ]:
                self._print_attr(a)
        for a in self.key_parameters:
            self._print_attr(a)
        mylog.info("Total datasets: %d.", len(self.all_outputs))

    def _print_attr(self, a):
        """
        Print the attribute or warn about it missing.
        """
        if not hasattr(self, a):
            mylog.error("Missing %s in dataset definition!", a)
            return
        v = getattr(self, a)
        mylog.info("Parameters: %-25s = %s", a, v)

    def _get_outputs_by_key(self, key, values, tolerance=None, outputs=None):
        r"""
        Get datasets at or near the given values.

        Parameters
        ----------
        key : str
            The key by which to retrieve outputs, usually 'time' or
            'redshift'.
        values : array_like
            A list of values, given as floats.
        tolerance : float
            If not None, do not return a dataset unless the value is
            within the tolerance value.  If None, simply return the
            nearest dataset.
            Default: None.
        outputs : list
            The list of outputs from which to choose.  If None,
            self.all_outputs is used.
            Default: None.

        Examples
        --------
        >>> datasets = es.get_outputs_by_key("redshift", [0, 1, 2], tolerance=0.1)
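
        Values can also be supplied as a ``(values, unit)`` tuple; the numbers
        below are purely illustrative:

        >>> datasets = es.get_outputs_by_key("time", ([500, 600], "Myr"), tolerance=10)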

        """

        if not isinstance(values, YTArray):
            if isinstance(values, tuple) and len(values) == 2:
                values = self.arr(*values)
            else:
                values = self.arr(values)
        values = values.in_base()

        if outputs is None:
            outputs = self.all_outputs
        my_outputs = []
        if not outputs:
            return my_outputs
        for value in values:
            outputs.sort(key=lambda obj: np.abs(value - obj[key]))
            if (
                tolerance is None or np.abs(value - outputs[0][key]) <= tolerance
            ) and outputs[0] not in my_outputs:
                my_outputs.append(outputs[0])
            else:
                mylog.error("No dataset added for %s = %f.", key, value)

        outputs.sort(key=lambda obj: obj["time"])
        return my_outputs
