1import functools 2import glob 3import inspect 4import os 5import weakref 6from functools import wraps 7 8import numpy as np 9from more_itertools import always_iterable 10 11from yt._maintenance.deprecation import issue_deprecation_warning 12from yt.config import ytcfg 13from yt.data_objects.analyzer_objects import AnalysisTask, create_quantity_proxy 14from yt.data_objects.particle_trajectories import ParticleTrajectories 15from yt.funcs import is_sequence, mylog 16from yt.units.yt_array import YTArray, YTQuantity 17from yt.utilities.exceptions import YTException 18from yt.utilities.object_registries import ( 19 analysis_task_registry, 20 data_object_registry, 21 derived_quantity_registry, 22 simulation_time_series_registry, 23) 24from yt.utilities.parallel_tools.parallel_analysis_interface import ( 25 communication_system, 26 parallel_objects, 27 parallel_root_only, 28) 29 30 31class AnalysisTaskProxy: 32 def __init__(self, time_series): 33 self.time_series = time_series 34 35 def __getitem__(self, key): 36 task_cls = analysis_task_registry[key] 37 38 @wraps(task_cls.__init__) 39 def func(*args, **kwargs): 40 task = task_cls(*args, **kwargs) 41 return self.time_series.eval(task) 42 43 return func 44 45 def keys(self): 46 return analysis_task_registry.keys() 47 48 def __contains__(self, key): 49 return key in analysis_task_registry 50 51 52def get_ds_prop(propname): 53 def _eval(params, ds): 54 return getattr(ds, propname) 55 56 cls = type(propname, (AnalysisTask,), dict(eval=_eval, _params=tuple())) 57 return cls 58 59 60attrs = ( 61 "refine_by", 62 "dimensionality", 63 "current_time", 64 "domain_dimensions", 65 "domain_left_edge", 66 "domain_right_edge", 67 "unique_identifier", 68 "current_redshift", 69 "cosmological_simulation", 70 "omega_matter", 71 "omega_lambda", 72 "omega_radiation", 73 "hubble_constant", 74) 75 76 77class TimeSeriesParametersContainer: 78 def __init__(self, data_object): 79 self.data_object = data_object 80 81 def __getattr__(self, attr): 82 if attr in attrs: 83 return self.data_object.eval(get_ds_prop(attr)()) 84 raise AttributeError(attr) 85 86 87class DatasetSeries: 88 r"""The DatasetSeries object is a container of multiple datasets, 89 allowing easy iteration and computation on them. 90 91 DatasetSeries objects are designed to provide easy ways to access, 92 analyze, parallelize and visualize multiple datasets sequentially. This is 93 primarily expressed through iteration, but can also be constructed via 94 analysis tasks (see :ref:`time-series-analysis`). 95 96 Note that contained datasets are lazily loaded and weakly referenced. This means 97 that in order to perform follow-up operations on data it's best to define handles on 98 these datasets during iteration. 99 100 Parameters 101 ---------- 102 outputs : list of filenames, or pattern 103 A list of filenames, for instance ["DD0001/DD0001", "DD0002/DD0002"], 104 or a glob pattern (i.e. containing wildcards '[]?!*') such as "DD*/DD*.index". 105 In the latter case, results are sorted automatically. 106 Filenames and patterns can be of type str, os.Pathlike or bytes. 107 parallel : True, False or int 108 This parameter governs the behavior when .piter() is called on the 109 resultant DatasetSeries object. If this is set to False, the time 110 series will not iterate in parallel when .piter() is called. If 111 this is set to either True, one processor will be allocated for 112 each iteration of the loop. If this is set to an integer, the loop 113 will be parallelized over this many workgroups. It the integer 114 value is less than the total number of available processors, 115 more than one processor will be allocated to a given loop iteration, 116 causing the functionality within the loop to be run in parallel. 117 setup_function : callable, accepts a ds 118 This function will be called whenever a dataset is loaded. 119 mixed_dataset_types : True or False, default False 120 Set to True if the DatasetSeries will load different dataset types, set 121 to False if loading dataset of a single type as this will result in a 122 considerable speed up from not having to figure out the dataset type. 123 124 Examples 125 -------- 126 127 >>> ts = DatasetSeries( 128 ... "GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_0[0-6][0-9]0" 129 ... ) 130 >>> for ds in ts: 131 ... SlicePlot(ds, "x", ("gas", "density")).save() 132 ... 133 >>> def print_time(ds): 134 ... print(ds.current_time) 135 ... 136 >>> ts = DatasetSeries( 137 ... "GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_0[0-6][0-9]0", 138 ... setup_function=print_time, 139 ... ) 140 ... 141 >>> for ds in ts: 142 ... SlicePlot(ds, "x", ("gas", "density")).save() 143 144 """ 145 146 def __init_subclass__(cls, *args, **kwargs): 147 super().__init_subclass__(*args, **kwargs) 148 code_name = cls.__name__[: cls.__name__.find("Simulation")] 149 if code_name: 150 simulation_time_series_registry[code_name] = cls 151 mylog.debug("Registering simulation: %s as %s", code_name, cls) 152 153 def __new__(cls, outputs, *args, **kwargs): 154 try: 155 outputs = cls._get_filenames_from_glob_pattern(outputs) 156 except TypeError: 157 pass 158 ret = super().__new__(cls) 159 ret._pre_outputs = outputs[:] 160 return ret 161 162 def __init__( 163 self, 164 outputs, 165 parallel=True, 166 setup_function=None, 167 mixed_dataset_types=False, 168 **kwargs, 169 ): 170 # This is needed to properly set _pre_outputs for Simulation subclasses. 171 self._mixed_dataset_types = mixed_dataset_types 172 if is_sequence(outputs) and not isinstance(outputs, str): 173 self._pre_outputs = outputs[:] 174 self.tasks = AnalysisTaskProxy(self) 175 self.params = TimeSeriesParametersContainer(self) 176 if setup_function is None: 177 178 def _null(x): 179 return None 180 181 setup_function = _null 182 self._setup_function = setup_function 183 for type_name in data_object_registry: 184 setattr( 185 self, type_name, functools.partial(DatasetSeriesObject, self, type_name) 186 ) 187 self.parallel = parallel 188 self.kwargs = kwargs 189 190 @staticmethod 191 def _get_filenames_from_glob_pattern(outputs): 192 """ 193 Helper function to DatasetSeries.__new__ 194 handle a special case where "outputs" is assumed to be really a pattern string 195 """ 196 pattern = outputs 197 epattern = os.path.expanduser(pattern) 198 data_dir = ytcfg.get("yt", "test_data_dir") 199 # if no match if found from the current work dir, 200 # we try to match the pattern from the test data dir 201 file_list = glob.glob(epattern) or glob.glob(os.path.join(data_dir, epattern)) 202 if not file_list: 203 raise FileNotFoundError(f"No match found for pattern : {pattern}") 204 return sorted(file_list) 205 206 def __getitem__(self, key): 207 if isinstance(key, slice): 208 if isinstance(key.start, float): 209 return self.get_range(key.start, key.stop) 210 # This will return a sliced up object! 211 return DatasetSeries( 212 self._pre_outputs[key], parallel=self.parallel, **self.kwargs 213 ) 214 o = self._pre_outputs[key] 215 if isinstance(o, (str, os.PathLike)): 216 o = self._load(o, **self.kwargs) 217 self._setup_function(o) 218 return o 219 220 def __len__(self): 221 return len(self._pre_outputs) 222 223 @property 224 def outputs(self): 225 return self._pre_outputs 226 227 def piter(self, storage=None, dynamic=False): 228 r"""Iterate over time series components in parallel. 229 230 This allows you to iterate over a time series while dispatching 231 individual components of that time series to different processors or 232 processor groups. If the parallelism strategy was set to be 233 multi-processor (by "parallel = N" where N is an integer when the 234 DatasetSeries was created) this will issue each dataset to an 235 N-processor group. For instance, this would allow you to start a 1024 236 processor job, loading up 100 datasets in a time series and creating 8 237 processor groups of 128 processors each, each of which would be 238 assigned a different dataset. This could be accomplished as shown in 239 the examples below. The *storage* option is as seen in 240 :func:`~yt.utilities.parallel_tools.parallel_analysis_interface.parallel_objects` 241 which is a mechanism for storing results of analysis on an individual 242 dataset and then combining the results at the end, so that the entire 243 set of processors have access to those results. 244 245 Note that supplying a *store* changes the iteration mechanism; see 246 below. 247 248 Parameters 249 ---------- 250 storage : dict 251 This is a dictionary, which will be filled with results during the 252 course of the iteration. The keys will be the dataset 253 indices and the values will be whatever is assigned to the *result* 254 attribute on the storage during iteration. 255 dynamic : boolean 256 This governs whether or not dynamic load balancing will be 257 enabled. This requires one dedicated processor; if this 258 is enabled with a set of 128 processors available, only 259 127 will be available to iterate over objects as one will 260 be load balancing the rest. 261 262 263 Examples 264 -------- 265 Here is an example of iteration when the results do not need to be 266 stored. One processor will be assigned to each dataset. 267 268 >>> ts = DatasetSeries("DD*/DD*.index") 269 >>> for ds in ts.piter(): 270 ... SlicePlot(ds, "x", ("gas", "density")).save() 271 ... 272 273 This demonstrates how one might store results: 274 275 >>> def print_time(ds): 276 ... print(ds.current_time) 277 ... 278 >>> ts = DatasetSeries("DD*/DD*.index", setup_function=print_time) 279 ... 280 >>> my_storage = {} 281 >>> for sto, ds in ts.piter(storage=my_storage): 282 ... v, c = ds.find_max(("gas", "density")) 283 ... sto.result = (v, c) 284 ... 285 >>> for i, (v, c) in sorted(my_storage.items()): 286 ... print("% 4i %0.3e" % (i, v)) 287 ... 288 289 This shows how to dispatch 4 processors to each dataset: 290 291 >>> ts = DatasetSeries("DD*/DD*.index", parallel=4) 292 >>> for ds in ts.piter(): 293 ... ProjectionPlot(ds, "x", ("gas", "density")).save() 294 ... 295 296 """ 297 if not self.parallel: 298 njobs = 1 299 elif not dynamic: 300 if self.parallel: 301 njobs = -1 302 else: 303 njobs = self.parallel 304 else: 305 my_communicator = communication_system.communicators[-1] 306 nsize = my_communicator.size 307 if nsize == 1: 308 self.parallel = False 309 dynamic = False 310 njobs = 1 311 else: 312 njobs = nsize - 1 313 314 for output in parallel_objects( 315 self._pre_outputs, njobs=njobs, storage=storage, dynamic=dynamic 316 ): 317 if storage is not None: 318 sto, output = output 319 320 if isinstance(output, str): 321 ds = self._load(output, **self.kwargs) 322 self._setup_function(ds) 323 else: 324 ds = output 325 326 if storage is not None: 327 next_ret = (sto, ds) 328 else: 329 next_ret = ds 330 331 yield next_ret 332 333 def eval(self, tasks, obj=None): 334 return_values = {} 335 for store, ds in self.piter(return_values): 336 store.result = [] 337 for task in always_iterable(tasks): 338 try: 339 style = inspect.getargspec(task.eval)[0][1] 340 if style == "ds": 341 arg = ds 342 elif style == "data_object": 343 if obj is None: 344 obj = DatasetSeriesObject(self, "all_data") 345 arg = obj.get(ds) 346 rv = task.eval(arg) 347 # We catch and store YT-originating exceptions 348 # This fixes the standard problem of having a sphere that's too 349 # small. 350 except YTException: 351 pass 352 store.result.append(rv) 353 return [v for k, v in sorted(return_values.items())] 354 355 @classmethod 356 def from_filenames(cls, filenames, parallel=True, setup_function=None, **kwargs): 357 r"""Create a time series from either a filename pattern or a list of 358 filenames. 359 360 This method provides an easy way to create a 361 :class:`~yt.data_objects.time_series.DatasetSeries`, given a set of 362 filenames or a pattern that matches them. Additionally, it can set the 363 parallelism strategy. 364 365 Parameters 366 ---------- 367 filenames : list or pattern 368 This can either be a list of filenames (such as ["DD0001/DD0001", 369 "DD0002/DD0002"]) or a pattern to match, such as 370 "DD*/DD*.index"). If it's the former, they will be loaded in 371 order. The latter will be identified with the glob module and then 372 sorted. 373 parallel : True, False or int 374 This parameter governs the behavior when .piter() is called on the 375 resultant DatasetSeries object. If this is set to False, the time 376 series will not iterate in parallel when .piter() is called. If 377 this is set to either True or an integer, it will be iterated with 378 1 or that integer number of processors assigned to each parameter 379 file provided to the loop. 380 setup_function : callable, accepts a ds 381 This function will be called whenever a dataset is loaded. 382 383 Examples 384 -------- 385 386 >>> def print_time(ds): 387 ... print(ds.current_time) 388 ... 389 >>> ts = DatasetSeries.from_filenames( 390 ... "GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_0[0-6][0-9]0", 391 ... setup_function=print_time, 392 ... ) 393 ... 394 >>> for ds in ts: 395 ... SlicePlot(ds, "x", ("gas", "density")).save() 396 397 """ 398 issue_deprecation_warning( 399 "DatasetSeries.from_filenames() is deprecated and will be removed " 400 "in a future version of yt. Use DatasetSeries() directly.", 401 since="4.0.0", 402 removal="4.1.0", 403 ) 404 obj = cls(filenames, parallel=parallel, setup_function=setup_function, **kwargs) 405 return obj 406 407 @classmethod 408 def from_output_log(cls, output_log, line_prefix="DATASET WRITTEN", parallel=True): 409 filenames = [] 410 for line in open(output_log): 411 if not line.startswith(line_prefix): 412 continue 413 cut_line = line[len(line_prefix) :].strip() 414 fn = cut_line.split()[0] 415 filenames.append(fn) 416 obj = cls(filenames, parallel=parallel) 417 return obj 418 419 _dataset_cls = None 420 421 def _load(self, output_fn, **kwargs): 422 from yt.loaders import load 423 424 if self._dataset_cls is not None: 425 return self._dataset_cls(output_fn, **kwargs) 426 elif self._mixed_dataset_types: 427 return load(output_fn, **kwargs) 428 ds = load(output_fn, **kwargs) 429 self._dataset_cls = ds.__class__ 430 return ds 431 432 def particle_trajectories( 433 self, indices, fields=None, suppress_logging=False, ptype=None 434 ): 435 r"""Create a collection of particle trajectories in time over a series of 436 datasets. 437 438 Parameters 439 ---------- 440 indices : array_like 441 An integer array of particle indices whose trajectories we 442 want to track. If they are not sorted they will be sorted. 443 fields : list of strings, optional 444 A set of fields that is retrieved when the trajectory 445 collection is instantiated. Default: None (will default 446 to the fields 'particle_position_x', 'particle_position_y', 447 'particle_position_z') 448 suppress_logging : boolean 449 Suppress yt's logging when iterating over the simulation time 450 series. Default: False 451 ptype : str, optional 452 Only use this particle type. Default: None, which uses all particle type. 453 454 Examples 455 -------- 456 >>> my_fns = glob.glob("orbit_hdf5_chk_00[0-9][0-9]") 457 >>> my_fns.sort() 458 >>> fields = [ 459 ... ("all", "particle_position_x"), 460 ... ("all", "particle_position_y"), 461 ... ("all", "particle_position_z"), 462 ... ("all", "particle_velocity_x"), 463 ... ("all", "particle_velocity_y"), 464 ... ("all", "particle_velocity_z"), 465 ... ] 466 >>> ds = load(my_fns[0]) 467 >>> init_sphere = ds.sphere(ds.domain_center, (0.5, "unitary")) 468 >>> indices = init_sphere[("all", "particle_index")].astype("int") 469 >>> ts = DatasetSeries(my_fns) 470 >>> trajs = ts.particle_trajectories(indices, fields=fields) 471 >>> for t in trajs: 472 ... print( 473 ... t[("all", "particle_velocity_x")].max(), 474 ... t[("all", "particle_velocity_x")].min(), 475 ... ) 476 477 Notes 478 ----- 479 This function will fail if there are duplicate particle ids or if some of the 480 particle disappear. 481 """ 482 return ParticleTrajectories( 483 self, indices, fields=fields, suppress_logging=suppress_logging, ptype=ptype 484 ) 485 486 487class TimeSeriesQuantitiesContainer: 488 def __init__(self, data_object, quantities): 489 self.data_object = data_object 490 self.quantities = quantities 491 492 def __getitem__(self, key): 493 if key not in self.quantities: 494 raise KeyError(key) 495 q = self.quantities[key] 496 497 def run_quantity_wrapper(quantity, quantity_name): 498 @wraps(derived_quantity_registry[quantity_name][1]) 499 def run_quantity(*args, **kwargs): 500 to_run = quantity(*args, **kwargs) 501 return self.data_object.eval(to_run) 502 503 return run_quantity 504 505 return run_quantity_wrapper(q, key) 506 507 508class DatasetSeriesObject: 509 def __init__(self, time_series, data_object_name, *args, **kwargs): 510 self.time_series = weakref.proxy(time_series) 511 self.data_object_name = data_object_name 512 self._args = args 513 self._kwargs = kwargs 514 qs = { 515 qn: create_quantity_proxy(qv) 516 for qn, qv in derived_quantity_registry.items() 517 } 518 self.quantities = TimeSeriesQuantitiesContainer(self, qs) 519 520 def eval(self, tasks): 521 return self.time_series.eval(tasks, self) 522 523 def get(self, ds): 524 # We get the type name, which corresponds to an attribute of the 525 # index 526 cls = getattr(ds, self.data_object_name) 527 return cls(*self._args, **self._kwargs) 528 529 530class SimulationTimeSeries(DatasetSeries): 531 def __init__(self, parameter_filename, find_outputs=False): 532 """ 533 Base class for generating simulation time series types. 534 Principally consists of a *parameter_filename*. 535 """ 536 537 if not os.path.exists(parameter_filename): 538 raise FileNotFoundError(parameter_filename) 539 self.parameter_filename = parameter_filename 540 self.basename = os.path.basename(parameter_filename) 541 self.directory = os.path.dirname(parameter_filename) 542 self.parameters = {} 543 self.key_parameters = [] 544 545 # Set some parameter defaults. 546 self._set_parameter_defaults() 547 # Read the simulation dataset. 548 self._parse_parameter_file() 549 # Set units 550 self._set_units() 551 # Figure out the starting and stopping times and redshift. 552 self._calculate_simulation_bounds() 553 # Get all possible datasets. 554 self._get_all_outputs(find_outputs=find_outputs) 555 556 self.print_key_parameters() 557 558 def _set_parameter_defaults(self): 559 pass 560 561 def _parse_parameter_file(self): 562 pass 563 564 def _set_units(self): 565 pass 566 567 def _calculate_simulation_bounds(self): 568 pass 569 570 def _get_all_outputs(**kwargs): 571 pass 572 573 def __repr__(self): 574 return self.parameter_filename 575 576 _arr = None 577 578 @property 579 def arr(self): 580 if self._arr is not None: 581 return self._arr 582 self._arr = functools.partial(YTArray, registry=self.unit_registry) 583 return self._arr 584 585 _quan = None 586 587 @property 588 def quan(self): 589 if self._quan is not None: 590 return self._quan 591 self._quan = functools.partial(YTQuantity, registry=self.unit_registry) 592 return self._quan 593 594 @parallel_root_only 595 def print_key_parameters(self): 596 """ 597 Print out some key parameters for the simulation. 598 """ 599 if self.simulation_type == "grid": 600 for a in ["domain_dimensions", "domain_left_edge", "domain_right_edge"]: 601 self._print_attr(a) 602 for a in ["initial_time", "final_time", "cosmological_simulation"]: 603 self._print_attr(a) 604 if getattr(self, "cosmological_simulation", False): 605 for a in [ 606 "box_size", 607 "omega_matter", 608 "omega_lambda", 609 "omega_radiation", 610 "hubble_constant", 611 "initial_redshift", 612 "final_redshift", 613 ]: 614 self._print_attr(a) 615 for a in self.key_parameters: 616 self._print_attr(a) 617 mylog.info("Total datasets: %d.", len(self.all_outputs)) 618 619 def _print_attr(self, a): 620 """ 621 Print the attribute or warn about it missing. 622 """ 623 if not hasattr(self, a): 624 mylog.error("Missing %s in dataset definition!", a) 625 return 626 v = getattr(self, a) 627 mylog.info("Parameters: %-25s = %s", a, v) 628 629 def _get_outputs_by_key(self, key, values, tolerance=None, outputs=None): 630 r""" 631 Get datasets at or near to given values. 632 633 Parameters 634 ---------- 635 key : str 636 The key by which to retrieve outputs, usually 'time' or 637 'redshift'. 638 values : array_like 639 A list of values, given as floats. 640 tolerance : float 641 If not None, do not return a dataset unless the value is 642 within the tolerance value. If None, simply return the 643 nearest dataset. 644 Default: None. 645 outputs : list 646 The list of outputs from which to choose. If None, 647 self.all_outputs is used. 648 Default: None. 649 650 Examples 651 -------- 652 >>> datasets = es.get_outputs_by_key("redshift", [0, 1, 2], tolerance=0.1) 653 654 """ 655 656 if not isinstance(values, YTArray): 657 if isinstance(values, tuple) and len(values) == 2: 658 values = self.arr(*values) 659 else: 660 values = self.arr(values) 661 values = values.in_base() 662 663 if outputs is None: 664 outputs = self.all_outputs 665 my_outputs = [] 666 if not outputs: 667 return my_outputs 668 for value in values: 669 outputs.sort(key=lambda obj: np.abs(value - obj[key])) 670 if ( 671 tolerance is None or np.abs(value - outputs[0][key]) <= tolerance 672 ) and outputs[0] not in my_outputs: 673 my_outputs.append(outputs[0]) 674 else: 675 mylog.error("No dataset added for %s = %f.", key, value) 676 677 outputs.sort(key=lambda obj: obj["time"]) 678 return my_outputs 679