1import inspect
2
3from yt.utilities.object_registries import analysis_task_registry
4
5
class AnalysisTask:
    """Base class for analysis tasks.

    Each subclass is recorded in ``analysis_task_registry`` under its class
    name at class-creation time (see ``__init_subclass__``).  Subclasses that
    rely on the default ``__init__`` and ``__repr__`` must define ``_params``,
    a sequence of parameter names.
    """

    def __init_subclass__(cls, *args, **kwargs):
        """Register the newly created subclass in ``analysis_task_registry``.

        A subclass that has a ``skip`` attribute with a falsy value is left
        out of the registry; every other subclass is registered.
        """
        # NOTE(review): registration is skipped when ``skip`` is *falsy*,
        # which reads as inverted relative to the attribute name -- confirm
        # the intent before changing it, since it alters registry contents.
        if hasattr(cls, "skip") and not cls.skip:
            return
        analysis_task_registry[cls.__name__] = cls

    def __init__(self, *args, **kwargs):
        """Store one value per name in ``_params`` as an instance attribute.

        Positional values are matched to ``_params`` in order; the remaining
        parameters must be supplied by keyword.

        Raises
        ------
        RuntimeError
            If the number of supplied values differs from ``len(_params)``,
            or if a keyword repeats a positionally supplied parameter or is
            not a declared parameter name.
        """
        # This should only get called if the subclassed object
        # does not override
        params = tuple(self._params)
        task_name = type(self).__name__
        n_given = len(args) + len(kwargs)
        if n_given != len(params):
            raise RuntimeError(
                f"{task_name} expects {len(params)} parameter(s) {params}, "
                f"but {n_given} were given"
            )
        positional = dict(zip(params, args))
        # The counts match, so any keyword that duplicates a positional
        # parameter or is unknown means some declared parameter stays unset.
        invalid = [key for key in kwargs if key in positional or key not in params]
        if invalid:
            raise RuntimeError(
                f"{task_name} got duplicate or unknown parameter(s) {invalid}; "
                f"expected each of {params} exactly once"
            )
        self.__dict__.update(positional)
        self.__dict__.update(kwargs)

    def __repr__(self):
        """Return ``"<ClassName>: p1=v1, p2=v2, ..."`` over ``_params``."""
        # Stolen from YTDataContainer.__repr__
        described = ", ".join(f"{name}={getattr(self, name)}" for name in self._params)
        return f"{self.__class__.__name__}: {described}"
25
26
def analysis_task(params=None):
    """Return a decorator that turns a function into an ``AnalysisTask`` class.

    Parameters
    ----------
    params : sequence of str, optional
        Parameter names stored on the generated class as ``_params``.
        Defaults to an empty tuple.

    Returns
    -------
    callable
        A decorator.  Applied to a function, it returns a new subclass of
        ``AnalysisTask`` named after that function, with the function
        installed as the class's ``eval`` method (so the task instance is
        passed as the function's first argument).
    """
    task_params = tuple() if params is None else params

    def create_new_class(func):
        # The class body is just the evaluation function and its parameters.
        namespace = {"eval": func, "_params": task_params}
        return type(func.__name__, (AnalysisTask,), namespace)

    return create_new_class
36
37
@analysis_task(("field",))
def MaximumValue(params, data_object):
    """Return element 0 of the ``MaxLocation`` quantity for ``params.field``.

    ``params`` is the task instance (it carries the ``field`` attribute) and
    ``data_object`` must expose a ``quantities`` mapping.
    """
    max_location = data_object.quantities["MaxLocation"]
    result = max_location(params.field)
    return result[0]
42
43
@analysis_task()
def CurrentTimeYears(params, ds):
    """Return ``ds.current_time`` multiplied by ``ds["years"]``.

    ``params`` is the task instance and is not used.
    """
    current_time = ds.current_time
    years = ds["years"]
    return current_time * years
47
48
class SlicePlotDataset(AnalysisTask):
    """Task that builds a ``SlicePlot`` for a dataset and saves it.

    The task parameters are ``field``, ``axis`` and ``center``.
    """

    _params = ["field", "axis", "center"]

    def __init__(self, *args, **kwargs):
        """Keep a reference to ``SlicePlot`` and store the task parameters."""
        # The plotting class is imported at construction time, not at
        # module import time.
        from yt.visualization.api import SlicePlot as plot_factory

        self.SlicePlot = plot_factory
        AnalysisTask.__init__(self, *args, **kwargs)

    def eval(self, ds):
        """Create the slice plot for ``ds`` and return what ``save()`` returns."""
        plot = self.SlicePlot(ds, self.axis, self.field, center=self.center)
        saved = plot.save()
        return saved
61
62
class QuantityProxy(AnalysisTask):
    """Task that evaluates a named quantity on a data object.

    Concrete subclasses set ``quantity_name`` (and ``_params``).  The
    arguments given at construction are stored untouched and forwarded to
    the quantity when ``eval`` is called.
    """

    _params = None
    quantity_name = None

    def __repr__(self):
        """Return ``"<ClassName>: arg1, arg2, key=value, ..."``."""
        # Stolen from YTDataContainer.__repr__
        # Positional and keyword arguments share one comma-separated list so
        # that a separator always sits between the two groups.
        parts = [f"{arg}" for arg in self.args]
        parts.extend(f"{k}={v}" for k, v in self.kwargs.items())
        return f"{self.__class__.__name__}: " + ", ".join(parts)

    def __init__(self, *args, **kwargs):
        """Remember the arguments to pass to the quantity later."""
        self.args = args
        self.kwargs = kwargs

    def eval(self, data_object):
        """Call ``data_object.quantities[quantity_name]`` with the stored arguments.

        Returns whatever the quantity returns.
        """
        rv = data_object.quantities[self.quantity_name](*self.args, **self.kwargs)
        return rv
81
82
class ParameterValue(AnalysisTask):
    """Task that reads one parameter from a dataset, optionally converting it."""

    _params = ["parameter"]

    def __init__(self, parameter, cast=None):
        """Store the parameter name and the conversion callable.

        Parameters
        ----------
        parameter
            Value passed to ``ds.get_parameter`` when the task is evaluated.
        cast : callable, optional
            Applied to the looked-up value.  When omitted, the value is
            returned unchanged.
        """
        if cast is None:
            # No conversion requested: fall back to a pass-through.
            def _identity(x):
                return x

            cast = _identity
        self.parameter = parameter
        self.cast = cast

    def eval(self, ds):
        """Return ``cast(ds.get_parameter(parameter))``."""
        convert = self.cast
        raw_value = ds.get_parameter(self.parameter)
        return convert(raw_value)
98
99
def create_quantity_proxy(quantity_object):
    """Build a ``QuantityProxy`` subclass for one quantity.

    Parameters
    ----------
    quantity_object : sequence
        ``quantity_object[0]`` is the quantity name, used both as the name of
        the new class and as its ``quantity_name``.  ``quantity_object[1]`` is
        the callable whose signature is inspected.

    Returns
    -------
    type
        A new subclass of ``QuantityProxy`` whose ``_params`` lists the
        callable's parameter names: its positional parameters without the
        first one, then any keyword-only parameters, then the name of its
        ``**`` parameter if it has one.
    """
    # getfullargspec replaces inspect.getargspec, which was removed in
    # Python 3.11.
    spec = inspect.getfullargspec(quantity_object[1])
    # Strip off 'data' which is on every quantity function
    params = spec.args[1:]
    params.extend(spec.kwonlyargs)
    if spec.varkw is not None:
        # varkw is a single name (a str): append it whole rather than
        # extending the list with its individual characters.
        params.append(spec.varkw)
    namespace = dict(_params=params, quantity_name=quantity_object[0])
    return type(quantity_object[0], (QuantityProxy,), namespace)
109