import inspect

from yt.utilities.object_registries import analysis_task_registry


class AnalysisTask:
    """Base class for analysis tasks evaluated against a data object.

    Every subclass is recorded in ``analysis_task_registry`` under its class
    name when the class is created (see ``__init_subclass__``). Subclasses list
    their parameter names in ``_params`` and implement ``eval``.
    """

    def __init_subclass__(cls, *args, **kwargs):
        # Chain to the parent hook so cooperative mixins are notified too.
        super().__init_subclass__(*args, **kwargs)
        # NOTE(review): as written, a class whose ``skip`` attribute is falsy
        # is NOT registered, while a truthy ``skip`` (or no ``skip`` at all)
        # is registered. That reads inverted relative to the attribute name;
        # confirm the intent before relying on it.
        if hasattr(cls, "skip") and not cls.skip:
            return
        analysis_task_registry[cls.__name__] = cls

    def __init__(self, *args, **kwargs):
        """Bind the given arguments to the names listed in ``_params``.

        This is only used by subclasses that do not override ``__init__``.
        Positional arguments are matched to ``_params`` in order; keyword
        arguments are stored under their own names.

        Raises
        ------
        RuntimeError
            If the total number of arguments differs from ``len(self._params)``.
        """
        n_given = len(args) + len(kwargs)
        if n_given != len(self._params):
            raise RuntimeError(
                f"{self.__class__.__name__} expects {len(self._params)} "
                f"argument(s) {tuple(self._params)}, got {n_given}"
            )
        self.__dict__.update(zip(self._params, args))
        self.__dict__.update(kwargs)

    def __repr__(self):
        # Stolen from YTDataContainer.__repr__
        s = f"{self.__class__.__name__}: "
        s += ", ".join(f"{i}={getattr(self, i)}" for i in self._params)
        return s


def analysis_task(params=None):
    """Return a decorator that turns a function into an ``AnalysisTask`` subclass.

    The decorated function becomes the new class's ``eval`` method. It is
    therefore called with the task instance (which holds the bound parameters)
    as its first argument, followed by the data object. The class is named
    after the function and, like every ``AnalysisTask`` subclass, is
    registered when it is created.

    Parameters
    ----------
    params : sequence of str, optional
        Names of the parameters the task accepts. Defaults to no parameters.
    """
    if params is None:
        params = ()

    def create_new_class(func):
        namespace = {
            "eval": func,
            "_params": params,
            "__doc__": func.__doc__,
            # Attribute the class to the module that defines the function,
            # not to this factory module (matters for introspection/pickling).
            "__module__": func.__module__,
        }
        return type(func.__name__, (AnalysisTask,), namespace)

    return create_new_class


@analysis_task(("field",))
def MaximumValue(params, data_object):
    # Returns element 0 of the "MaxLocation" quantity result (presumably the
    # maximum value itself; verify against the quantity's return layout).
    v = data_object.quantities["MaxLocation"](params.field)[0]
    return v


@analysis_task()
def CurrentTimeYears(params, ds):
    # NOTE(review): relies on ``ds["years"]`` acting as a code-time -> years
    # conversion factor; confirm that this lookup still exists on the
    # Dataset class in use.
    return ds.current_time * ds["years"]


class SlicePlotDataset(AnalysisTask):
    """Task that makes a slice plot of a dataset and saves it.

    ``eval`` returns whatever the plot's ``save()`` returns.
    """

    _params = ["field", "axis", "center"]

    def __init__(self, *args, **kwargs):
        # Deferred import: the visualization stack is only loaded when this
        # task is actually instantiated.
        from yt.visualization.api import SlicePlot

        self.SlicePlot = SlicePlot
        AnalysisTask.__init__(self, *args, **kwargs)

    def eval(self, ds):
        slc = self.SlicePlot(ds, self.axis, self.field, center=self.center)
        return slc.save()


class QuantityProxy(AnalysisTask):
    """Task that evaluates a named derived quantity on a data object.

    Concrete subclasses are generated by ``create_quantity_proxy``. Arguments
    passed at construction are stored verbatim and forwarded to the quantity
    in ``eval``.
    """

    _params = None
    quantity_name = None

    def __repr__(self):
        # Format as "Name: a1, a2, k1=v1", with positional and keyword
        # arguments joined by the same separator.
        parts = [f"{arg}" for arg in self.args]
        parts.extend(f"{k}={v}" for k, v in self.kwargs.items())
        return f"{self.__class__.__name__}: " + ", ".join(parts)

    def __init__(self, *args, **kwargs):
        # Arguments are not validated against ``_params``; they are kept as-is
        # for forwarding in ``eval``.
        self.args = args
        self.kwargs = kwargs

    def eval(self, data_object):
        rv = data_object.quantities[self.quantity_name](*self.args, **self.kwargs)
        return rv


class ParameterValue(AnalysisTask):
    """Task that returns ``ds.get_parameter(parameter)``, passed through ``cast``.

    When ``cast`` is None, the parameter value is returned unchanged.
    """

    _params = ["parameter"]

    def __init__(self, parameter, cast=None):
        self.parameter = parameter
        if cast is None:

            def _identity(x):
                return x

            cast = _identity
        self.cast = cast

    def eval(self, ds):
        return self.cast(ds.get_parameter(self.parameter))


def create_quantity_proxy(quantity_object):
    """Build a ``QuantityProxy`` subclass for a ``(name, function)`` pair.

    Parameters
    ----------
    quantity_object : sequence
        Element 0 is the quantity name; element 1 is the quantity function.
        The function's first positional argument (the data object) is
        dropped from the recorded parameter names.

    Returns
    -------
    type
        A ``QuantityProxy`` subclass named after the quantity. Its
        ``quantity_name`` is that name, and its ``_params`` lists the
        remaining positional names, then keyword-only names, then the
        ``**`` name if the function has one.
    """
    name = quantity_object[0]
    func = quantity_object[1]
    # inspect.getargspec was removed in Python 3.11 (and rejected keyword-only
    # arguments and annotations before that); getfullargspec handles both.
    spec = inspect.getfullargspec(func)
    # Strip off 'data' which is on every quantity function
    params = spec.args[1:] + spec.kwonlyargs
    if spec.varkw is not None:
        # Record the ``**`` parameter name as ONE entry; ``list += str`` would
        # add it one character at a time.
        params.append(spec.varkw)
    dd = dict(_params=params, quantity_name=name)
    return type(name, (QuantityProxy,), dd)