1import os 2import weakref 3from collections import defaultdict 4from contextlib import contextmanager 5 6import numpy as np 7 8from yt.data_objects.field_data import YTFieldData 9from yt.data_objects.profiles import create_profile 10from yt.fields.field_exceptions import NeedsGridType 11from yt.frontends.ytdata.utilities import save_as_dataset 12from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog 13from yt.units.yt_array import YTArray, YTQuantity, uconcatenate 14from yt.utilities.amr_kdtree.api import AMRKDTree 15from yt.utilities.exceptions import ( 16 YTCouldNotGenerateField, 17 YTException, 18 YTFieldNotFound, 19 YTFieldNotParseable, 20 YTFieldTypeNotFound, 21 YTNonIndexedDataContainer, 22 YTSpatialFieldUnitError, 23) 24from yt.utilities.object_registries import data_object_registry 25from yt.utilities.parameter_file_storage import ParameterFileStore 26 27 28def sanitize_weight_field(ds, field, weight): 29 field_object = ds._get_field_info(field) 30 if weight is None: 31 if field_object.sampling_type == "particle": 32 if field_object.name[0] == "gas": 33 ptype = ds._sph_ptypes[0] 34 else: 35 ptype = field_object.name[0] 36 weight_field = (ptype, "particle_ones") 37 else: 38 weight_field = ("index", "ones") 39 else: 40 weight_field = weight 41 return weight_field 42 43 44def _get_ipython_key_completion(ds): 45 # tuple-completion (ftype, fname) was added in IPython 8.0.0 46 # with earlier versions, completion works with fname only 47 # this implementation should work transparently with all IPython versions 48 tuple_keys = ds.field_list + ds.derived_field_list 49 fnames = list({k[1] for k in tuple_keys}) 50 return tuple_keys + fnames 51 52 53class YTDataContainer: 54 """ 55 Generic YTDataContainer container. By itself, will attempt to 56 generate field, read fields (method defined by derived classes) 57 and deal with passing back and forth field parameters. 
    """

    # Class-level defaults. They apply to every instance until set on the
    # instance or overridden by a subclass.
    _chunk_info = None
    _num_ghost_zones = 0
    # Attribute names stored as extra attributes by save_as_dataset.
    _con_args = ()
    # When True, __init_subclass__ does not register the subclass.
    _skip_add = False
    # Fields produced by _generate_container_field rather than by get_data.
    _container_fields = ()
    _tds_attrs = ()
    _tds_fields = ()
    _field_cache = None
    # Cache for the ``index`` property.
    _index = None

    def __init__(self, ds, field_parameters):
        """
        Typically this is never called directly, but only due to inheritance.
        It associates a :class:`~yt.data_objects.static_output.Dataset` with the class,
        sets its initial set of fields, and the remainder of the arguments
        are passed as field_parameters.

        Raises
        ------
        RuntimeError
            If ``ds`` is None and no ``ds`` attribute is available on the class.
        """
        # ds is typically set in the new object type created in
        # Dataset._add_object_class but it can also be passed as a parameter to the
        # constructor, in which case it will override the default.
        # This code ensures it is never not set.
        if ds is not None:
            self.ds = ds
        else:
            if not hasattr(self, "ds"):
                raise RuntimeError(
                    "Error: ds must be set either through class type "
                    "or parameter to the constructor"
                )

        self._current_particle_type = "all"
        self._current_fluid_type = self.ds.default_fluid_type
        # The dataset keeps only a weak proxy, so it does not keep this
        # container alive.
        self.ds.objects.append(weakref.proxy(self))
        mylog.debug("Appending object to %s (type: %s)", self.ds, type(self))
        self.field_data = YTFieldData()
        # Default magnetic field unit follows the dataset's unit system.
        if self.ds.unit_system.has_current_mks:
            mag_unit = "T"
        else:
            mag_unit = "G"
        self._default_field_parameters = {
            "center": self.ds.arr(np.zeros(3, dtype="float64"), "cm"),
            "bulk_velocity": self.ds.arr(np.zeros(3, dtype="float64"), "cm/s"),
            "bulk_magnetic_field": self.ds.arr(np.zeros(3, dtype="float64"), mag_unit),
            "normal": self.ds.arr([0.0, 0.0, 1.0], ""),
        }
        if field_parameters is None:
            field_parameters = {}
        # Defaults first, so user-supplied parameters override them.
        self._set_default_field_parameters()
        for key, val in field_parameters.items():
            self.set_field_parameter(key, val)

    def __init_subclass__(cls, *args, **kwargs):
        # Register every subclass that defines ``_type_name`` (unless
        # ``_skip_add`` is set) in data_object_registry. The key is
        # ``_override_selector_name`` when the subclass defines it.
        super().__init_subclass__(*args, **kwargs)
        if hasattr(cls, "_type_name") and not cls._skip_add:
            name = getattr(cls, "_override_selector_name", cls._type_name)
            data_object_registry[name] = cls

    @property
    def pf(self):
        # Alias for ``self.ds``; returns None when ``ds`` is not set.
        return getattr(self, "ds", None)

    @property
    def index(self):
        # Fetch ``self.ds.index`` on first access and cache it on the instance.
        if self._index is not None:
            return self._index
        self._index = self.ds.index
        return self._index

    def _debug(self):
        """
        When called from within a derived field, this will run pdb. However,
        during field detection, it will not. This allows you to more easily
        debug fields that are being called on actual objects.
        """
        import pdb

        pdb.set_trace()

    def _set_default_field_parameters(self):
        # Replace all field parameters with the default set.
        self.field_parameters = {}
        for k, v in self._default_field_parameters.items():
            self.set_field_parameter(k, v)

    def _is_default_field_parameter(self, parameter):
        # True only if the current value is the very same object as the
        # default. This is an identity comparison, not an equality one.
        if parameter not in self._default_field_parameters:
            return False
        return (
            self._default_field_parameters[parameter]
            is self.field_parameters[parameter]
        )

    def apply_units(self, arr, units):
        # Convert ``arr`` to ``units``. This first rebinds arr's unit registry
        # to the dataset's registry, which mutates ``arr.units``. Objects
        # without ``.units`` are wrapped in a new array with ``units``.
        try:
            arr.units.registry = self.ds.unit_registry
            return arr.to(units)
        except AttributeError:
            return self.ds.arr(arr, units=units)

    def _first_matching_field(self, field):
        # Return the first (ftype, fname) in derived_field_list whose fname
        # equals ``field``. Raise YTFieldNotFound if there is no match.
        for ftype, fname in self.ds.derived_field_list:
            if fname == field:
                return (ftype, fname)

        raise YTFieldNotFound(field, self.ds)

    def _set_center(self, center):
        # Normalise ``center`` into a code_length array and store it both as
        # ``self.center`` and as the "center" field parameter.
        # Accepted inputs:
        #   - None
        #   - a YTArray
        #   - a list/tuple/ndarray, of YTQuantity or of plain numbers
        #     (the latter are taken as code_length)
        #   - one of the strings "c"/"center", "max"/"m", "min",
        #     "max_<field>" or "min_<field>"
        #   - any other object, passed to ds.arr as code_length
        if center is None:
            self.center = None
            return
        elif isinstance(center, YTArray):
            self.center = self.ds.arr(center.astype("float64"))
            self.center.convert_to_units("code_length")
        elif isinstance(center, (list, tuple, np.ndarray)):
            if isinstance(center[0], YTQuantity):
                self.center = self.ds.arr([c.copy() for c in center], dtype="float64")
                self.center.convert_to_units("code_length")
            else:
                self.center = self.ds.arr(center, "code_length", dtype="float64")
        elif isinstance(center, str):
            # NOTE(review): a string matching none of these branches leaves
            # self.center unassigned by this call, and the ndim check below
            # then uses any previous value (or fails). Confirm intended.
            if center.lower() in ("c", "center"):
                self.center = self.ds.domain_center
            # is this dangerous for race conditions?
            elif center.lower() in ("max", "m"):
                self.center = self.ds.find_max(("gas", "density"))[1]
            elif center.startswith("max_"):
                field = self._first_matching_field(center[4:])
                self.center = self.ds.find_max(field)[1]
            elif center.lower() == "min":
                self.center = self.ds.find_min(("gas", "density"))[1]
            elif center.startswith("min_"):
                field = self._first_matching_field(center[4:])
                self.center = self.ds.find_min(field)[1]
        else:
            self.center = self.ds.arr(center, "code_length", dtype="float64")

        if self.center.ndim > 1:
            mylog.debug("Removing singleton dimensions from 'center'.")
            self.center = np.squeeze(self.center)
            if self.center.ndim > 1:
                msg = (
                    "center array must be 1 dimensional, supplied center has "
                    f"{self.center.ndim} dimensions with shape {self.center.shape}."
                )
                raise YTException(msg)

        self.set_field_parameter("center", self.center)

    def get_field_parameter(self, name, default=None):
        """
        This is typically only used by derived field functions, but
        it returns parameters used to generate fields.
        Returns ``default`` when ``name`` has not been set.
        """
        if name in self.field_parameters:
            return self.field_parameters[name]
        else:
            return default

    def set_field_parameter(self, name, val):
        """
        Here we set up dictionaries that get passed up and down and ultimately
        to derived fields.
        """
        self.field_parameters[name] = val

    def has_field_parameter(self, name):
        """
        Checks if a field parameter is set.
        """
        return name in self.field_parameters

    def clear_data(self):
        """
        Clears out all data from the YTDataContainer instance, freeing memory.
        """
        self.field_data.clear()

    def has_key(self, key):
        """
        Checks if a data field already exists.
        Only ``self.field_data`` is consulted; nothing is generated.
        """
        return key in self.field_data

    def keys(self):
        # Keys of the fields currently held in memory.
        return self.field_data.keys()

    def _reshape_vals(self, arr):
        # Identity hook; returns ``arr`` unchanged here.
        return arr

    def __getitem__(self, key):
        """
        Returns a single field. Will add if necessary.

        Container fields are generated via _generate_container_field.
        Everything else is loaded through get_data and read back from
        ``self.field_data``.
        """
        f = self._determine_fields([key])[0]
        if f not in self.field_data and key not in self.field_data:
            if f in self._container_fields:
                self.field_data[f] = self.ds.arr(self._generate_container_field(f))
                return self.field_data[f]
            else:
                self.get_data(f)
        # fi.units is the unit expression string. We depend on the registry
        # hanging off the dataset to define this unit object.
        # Note that this is less succinct so that we can account for the case
        # when there are, for example, no elements in the object.
        try:
            rv = self.field_data[f]
        except KeyError:
            # Fall back to the data stored under the raw ``key``, attaching
            # the units from the field info.
            # NOTE(review): if ``f`` is neither a tuple nor bytes, ``fi`` is
            # never bound and the next line raises UnboundLocalError.
            if isinstance(f, tuple):
                fi = self.ds._get_field_info(*f)
            elif isinstance(f, bytes):
                fi = self.ds._get_field_info("unknown", f)
            rv = self.ds.arr(self.field_data[key], fi.units)
        return rv

    def _ipython_key_completions_(self):
        # IPython hook: candidates for ``obj[<TAB>`` completion.
        return _get_ipython_key_completion(self.ds)

    def __setitem__(self, key, val):
        """
        Sets a field to be some other value.
279 """ 280 self.field_data[key] = val 281 282 def __delitem__(self, key): 283 """ 284 Deletes a field 285 """ 286 if key not in self.field_data: 287 key = self._determine_fields(key)[0] 288 del self.field_data[key] 289 290 def _generate_field(self, field): 291 ftype, fname = field 292 finfo = self.ds._get_field_info(*field) 293 with self._field_type_state(ftype, finfo): 294 if fname in self._container_fields: 295 tr = self._generate_container_field(field) 296 if finfo.sampling_type == "particle": 297 tr = self._generate_particle_field(field) 298 else: 299 tr = self._generate_fluid_field(field) 300 if tr is None: 301 raise YTCouldNotGenerateField(field, self.ds) 302 return tr 303 304 def _generate_fluid_field(self, field): 305 # First we check the validator 306 ftype, fname = field 307 finfo = self.ds._get_field_info(ftype, fname) 308 if self._current_chunk is None or self._current_chunk.chunk_type != "spatial": 309 gen_obj = self 310 else: 311 gen_obj = self._current_chunk.objs[0] 312 gen_obj.field_parameters = self.field_parameters 313 try: 314 finfo.check_available(gen_obj) 315 except NeedsGridType as ngt_exception: 316 rv = self._generate_spatial_fluid(field, ngt_exception.ghost_zones) 317 else: 318 rv = finfo(gen_obj) 319 return rv 320 321 def _generate_spatial_fluid(self, field, ngz): 322 finfo = self.ds._get_field_info(*field) 323 if finfo.units is None: 324 raise YTSpatialFieldUnitError(field) 325 units = finfo.units 326 try: 327 rv = self.ds.arr(np.zeros(self.ires.size, dtype="float64"), units) 328 accumulate = False 329 except YTNonIndexedDataContainer: 330 # In this case, we'll generate many tiny arrays of unknown size and 331 # then concatenate them. 
332 outputs = [] 333 accumulate = True 334 ind = 0 335 if ngz == 0: 336 deps = self._identify_dependencies([field], spatial=True) 337 deps = self._determine_fields(deps) 338 for _io_chunk in self.chunks([], "io", cache=False): 339 for _chunk in self.chunks([], "spatial", ngz=0, preload_fields=deps): 340 o = self._current_chunk.objs[0] 341 if accumulate: 342 rv = self.ds.arr(np.empty(o.ires.size, dtype="float64"), units) 343 outputs.append(rv) 344 ind = 0 # Does this work with mesh? 345 with o._activate_cache(): 346 ind += o.select( 347 self.selector, source=self[field], dest=rv, offset=ind 348 ) 349 else: 350 chunks = self.index._chunk(self, "spatial", ngz=ngz) 351 for chunk in chunks: 352 with self._chunked_read(chunk): 353 gz = self._current_chunk.objs[0] 354 gz.field_parameters = self.field_parameters 355 wogz = gz._base_grid 356 if accumulate: 357 rv = self.ds.arr( 358 np.empty(wogz.ires.size, dtype="float64"), units 359 ) 360 outputs.append(rv) 361 ind += wogz.select( 362 self.selector, 363 source=gz[field][ngz:-ngz, ngz:-ngz, ngz:-ngz], 364 dest=rv, 365 offset=ind, 366 ) 367 if accumulate: 368 rv = uconcatenate(outputs) 369 return rv 370 371 def _generate_particle_field(self, field): 372 # First we check the validator 373 ftype, fname = field 374 if self._current_chunk is None or self._current_chunk.chunk_type != "spatial": 375 gen_obj = self 376 else: 377 gen_obj = self._current_chunk.objs[0] 378 try: 379 finfo = self.ds._get_field_info(*field) 380 finfo.check_available(gen_obj) 381 except NeedsGridType as ngt_exception: 382 if ngt_exception.ghost_zones != 0: 383 raise NotImplementedError from ngt_exception 384 size = self._count_particles(ftype) 385 rv = self.ds.arr(np.empty(size, dtype="float64"), finfo.units) 386 ind = 0 387 for _io_chunk in self.chunks([], "io", cache=False): 388 for _chunk in self.chunks(field, "spatial"): 389 x, y, z = (self[ftype, f"particle_position_{ax}"] for ax in "xyz") 390 if x.size == 0: 391 continue 392 mask = 
self._current_chunk.objs[0].select_particles( 393 self.selector, x, y, z 394 ) 395 if mask is None: 396 continue 397 # This requests it from the grid and does NOT mask it 398 data = self[field][mask] 399 rv[ind : ind + data.size] = data 400 ind += data.size 401 else: 402 with self._field_type_state(ftype, finfo, gen_obj): 403 rv = self.ds._get_field_info(*field)(gen_obj) 404 return rv 405 406 def _count_particles(self, ftype): 407 for (f1, _f2), val in self.field_data.items(): 408 if f1 == ftype: 409 return val.size 410 size = 0 411 for _io_chunk in self.chunks([], "io", cache=False): 412 for _chunk in self.chunks([], "spatial"): 413 x, y, z = (self[ftype, f"particle_position_{ax}"] for ax in "xyz") 414 if x.size == 0: 415 continue 416 size += self._current_chunk.objs[0].count_particles( 417 self.selector, x, y, z 418 ) 419 return size 420 421 def _generate_container_field(self, field): 422 raise NotImplementedError 423 424 def _parameter_iterate(self, seq): 425 for obj in seq: 426 old_fp = obj.field_parameters 427 obj.field_parameters = self.field_parameters 428 yield obj 429 obj.field_parameters = old_fp 430 431 _key_fields = None 432 433 def write_out(self, filename, fields=None, format="%0.16e"): 434 """Write out the YTDataContainer object in a text file. 435 436 This function will take a data object and produce a tab delimited text 437 file containing the fields presently existing and the fields given in 438 the ``fields`` list. 439 440 Parameters 441 ---------- 442 filename : String 443 The name of the file to write to. 444 445 fields : List of string, Default = None 446 If this is supplied, these fields will be added to the list of 447 fields to be saved to disk. If not supplied, whatever fields 448 presently exist will be used. 449 450 format : String, Default = "%0.16e" 451 Format of numbers to be written in the file. 452 453 Raises 454 ------ 455 ValueError 456 Raised when there is no existing field. 
457 458 YTException 459 Raised when field_type of supplied fields is inconsistent with the 460 field_type of existing fields. 461 462 Examples 463 -------- 464 >>> ds = fake_particle_ds() 465 >>> sp = ds.sphere(ds.domain_center, 0.25) 466 >>> sp.write_out("sphere_1.txt") 467 >>> sp.write_out("sphere_2.txt", fields=["cell_volume"]) 468 """ 469 if fields is None: 470 fields = sorted(self.field_data.keys()) 471 472 if self._key_fields is None: 473 raise ValueError 474 475 field_order = [("index", k) for k in self._key_fields] 476 diff_fields = [field for field in fields if field not in field_order] 477 field_order += diff_fields 478 field_order = sorted(self._determine_fields(field_order)) 479 480 field_shapes = defaultdict(list) 481 for field in field_order: 482 shape = self[field].shape 483 field_shapes[shape].append(field) 484 485 # Check all fields have the same shape 486 if len(field_shapes) != 1: 487 err_msg = ["Got fields with different number of elements:\n"] 488 for shape, these_fields in field_shapes.items(): 489 err_msg.append(f"\t {these_fields} with shape {shape}") 490 raise YTException("\n".join(err_msg)) 491 492 with open(filename, "w") as fid: 493 field_header = [str(f) for f in field_order] 494 fid.write("\t".join(["#"] + field_header + ["\n"])) 495 field_data = np.array([self.field_data[field] for field in field_order]) 496 for line in range(field_data.shape[1]): 497 field_data[:, line].tofile(fid, sep="\t", format=format) 498 fid.write("\n") 499 500 def to_dataframe(self, fields): 501 r"""Export a data object to a :class:`~pandas.DataFrame`. 502 503 This function will take a data object and an optional list of fields 504 and export them to a :class:`~pandas.DataFrame` object. 505 If pandas is not importable, this will raise ImportError. 506 507 Parameters 508 ---------- 509 fields : list of strings or tuple field names 510 This is the list of fields to be exported into 511 the DataFrame. 

        Returns
        -------
        df : :class:`~pandas.DataFrame`
            The data contained in the object. Columns are named by the field
            name only (the last element of each resolved field tuple).

        Examples
        --------
        >>> dd = ds.all_data()
        >>> df = dd.to_dataframe([("gas", "density"), ("gas", "temperature")])
        """
        from yt.utilities.on_demand_imports import _pandas as pd

        data = {}
        fields = self._determine_fields(fields)
        for field in fields:
            # Key by the bare field name. Fields sharing a name across field
            # types overwrite each other here.
            data[field[-1]] = self[field]
        df = pd.DataFrame(data)
        return df

    def to_astropy_table(self, fields):
        """
        Export region data to a :class:`~astropy.table.QTable`,
        which is a Table object which is unit-aware. The QTable can then
        be exported to an ASCII file, FITS file, etc.

        See the AstroPy Table docs for more details:
        http://docs.astropy.org/en/stable/table/

        Parameters
        ----------
        fields : list of strings or tuple field names
            This is the list of fields to be exported into
            the QTable. Columns are named by the bare field name.

        Examples
        --------
        >>> sp = ds.sphere("c", (1.0, "Mpc"))
        >>> t = sp.to_astropy_table([("gas", "density"), ("gas", "temperature")])
        """
        from astropy.table import QTable

        t = QTable()
        fields = self._determine_fields(fields)
        for field in fields:
            # Each column is the field converted with YTArray.to_astropy().
            t[field[-1]] = self[field].to_astropy()
        return t

    def save_as_dataset(self, filename=None, fields=None):
        r"""Export a data object to a reloadable yt dataset.

        This function will take a data object and output a dataset
        containing either the fields presently existing or fields
        given in the ``fields`` list. The resulting dataset can be
        reloaded as a yt dataset.

        Parameters
        ----------
        filename : str, optional
            The name of the file to be written. If None, the name
            will be a combination of the original dataset and the type
            of data container.
        fields : list of string or tuple field names, optional
            If this is supplied, it is the list of fields to be saved to
            disk.
If not supplied, all the fields that have been queried 577 will be saved. 578 579 Returns 580 ------- 581 filename : str 582 The name of the file that has been created. 583 584 Examples 585 -------- 586 587 >>> import yt 588 >>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046") 589 >>> sp = ds.sphere(ds.domain_center, (10, "Mpc")) 590 >>> fn = sp.save_as_dataset(fields=[("gas", "density"), ("gas", "temperature")]) 591 >>> sphere_ds = yt.load(fn) 592 >>> # the original data container is available as the data attribute 593 >>> print(sds.data[("gas", "density")]) 594 [ 4.46237613e-32 4.86830178e-32 4.46335118e-32 ..., 6.43956165e-30 595 3.57339907e-30 2.83150720e-30] g/cm**3 596 >>> ad = sphere_ds.all_data() 597 >>> print(ad[("gas", "temperature")]) 598 [ 1.00000000e+00 1.00000000e+00 1.00000000e+00 ..., 4.40108359e+04 599 4.54380547e+04 4.72560117e+04] K 600 601 """ 602 603 keyword = f"{str(self.ds)}_{self._type_name}" 604 filename = get_output_filename(filename, keyword, ".h5") 605 606 data = {} 607 if fields is not None: 608 for f in self._determine_fields(fields): 609 data[f] = self[f] 610 else: 611 data.update(self.field_data) 612 # get the extra fields needed to reconstruct the container 613 tds_fields = tuple(("index", t) for t in self._tds_fields) 614 for f in [f for f in self._container_fields + tds_fields if f not in data]: 615 data[f] = self[f] 616 data_fields = list(data.keys()) 617 618 need_grid_positions = False 619 need_particle_positions = False 620 ptypes = [] 621 ftypes = {} 622 for field in data_fields: 623 if field in self._container_fields: 624 ftypes[field] = "grid" 625 need_grid_positions = True 626 elif self.ds.field_info[field].sampling_type == "particle": 627 if field[0] not in ptypes: 628 ptypes.append(field[0]) 629 ftypes[field] = field[0] 630 need_particle_positions = True 631 else: 632 ftypes[field] = "grid" 633 need_grid_positions = True 634 # projections and slices use px and py, so don't need positions 635 if self._type_name in 
["cutting", "proj", "slice", "quad_proj"]: 636 need_grid_positions = False 637 638 if need_particle_positions: 639 for ax in self.ds.coordinates.axis_order: 640 for ptype in ptypes: 641 p_field = (ptype, f"particle_position_{ax}") 642 if p_field in self.ds.field_info and p_field not in data: 643 data_fields.append(field) 644 ftypes[p_field] = p_field[0] 645 data[p_field] = self[p_field] 646 if need_grid_positions: 647 for ax in self.ds.coordinates.axis_order: 648 g_field = ("index", ax) 649 if g_field in self.ds.field_info and g_field not in data: 650 data_fields.append(g_field) 651 ftypes[g_field] = "grid" 652 data[g_field] = self[g_field] 653 g_field = ("index", "d" + ax) 654 if g_field in self.ds.field_info and g_field not in data: 655 data_fields.append(g_field) 656 ftypes[g_field] = "grid" 657 data[g_field] = self[g_field] 658 659 extra_attrs = { 660 arg: getattr(self, arg, None) for arg in self._con_args + self._tds_attrs 661 } 662 extra_attrs["con_args"] = repr(self._con_args) 663 extra_attrs["data_type"] = "yt_data_container" 664 extra_attrs["container_type"] = self._type_name 665 extra_attrs["dimensionality"] = self._dimensionality 666 save_as_dataset( 667 self.ds, filename, data, field_types=ftypes, extra_attrs=extra_attrs 668 ) 669 670 return filename 671 672 def to_glue(self, fields, label="yt", data_collection=None): 673 """ 674 Takes specific *fields* in the container and exports them to 675 Glue (http://glueviz.org) for interactive 676 analysis. Optionally add a *label*. If you are already within 677 the Glue environment, you can pass a *data_collection* object, 678 otherwise Glue will be started. 
        """
        from glue.core import Data, DataCollection

        from yt.config import ytcfg

        # Under yt's test configuration use glue's base Application class
        # instead of the Qt application.
        if ytcfg.get("yt", "internals", "within_testing"):
            from glue.core.application_base import Application as GlueApplication
        else:
            try:
                from glue.app.qt.application import GlueApplication
            except ImportError:
                # fall back to the older glue module layout
                from glue.qt.glue_application import GlueApplication
        gdata = Data(label=label)
        for component_name in fields:
            gdata.add_component(self[component_name], component_name)

        if data_collection is None:
            dc = DataCollection([gdata])
            app = GlueApplication(dc)
            try:
                app.start()
            except AttributeError:
                # In testing we're using a dummy glue application object
                # that doesn't have a start method
                pass
        else:
            data_collection.append(gdata)

    def create_firefly_object(
        self,
        path_to_firefly,
        fields_to_include=None,
        fields_units=None,
        default_decimation_factor=100,
        velocity_units="km/s",
        coordinate_units="kpc",
        show_unused_fields=0,
        dataset_name="yt",
    ):
        r"""This function links a region of data stored in a yt dataset
        to the Python frontend API for [Firefly](github.com/ageller/Firefly),
        a browser-based particle visualization platform.

        Parameters
        ----------
        path_to_firefly : string
            The (ideally) absolute path to the directory containing the index.html
            file of Firefly.

        fields_to_include : array_like of strings
            A list of fields that you want to include in your
            Firefly visualization for on-the-fly filtering and
            colormapping.

        fields_units : array_like of strings
            The units for each entry of ``fields_to_include``, in the same
            order; both lists must have the same length (otherwise a
            RuntimeError is raised). A unit written as ``"log(<unit>)"``
            requests the base-10 logarithm of the field, and the field is
            shown in Firefly under the name ``log<field>``.

        default_decimation_factor : integer
            The factor by which you want to decimate each particle group
            by (e.g. if there are 1e7 total particles in your simulation
            you might want to set this to 100 at first). Randomly samples
            your data like `shuffled_data[::decimation_factor]` so as to
            not overtax a system.
            This is adjustable on a per particle group
            basis by changing the returned reader's
            `reader.particleGroups[i].decimation_factor` before calling
            `reader.dumpToJSON()`.

        velocity_units : string
            The units that the velocity should be converted to in order to
            show streamlines in Firefly. Defaults to km/s.
        coordinate_units : string
            The units that the coordinates should be converted to. Defaults to
            kpc.
        show_unused_fields : boolean
            A flag to optionally log (as warnings) the fields that are
            available in the dataset but were not explicitly requested to be
            tracked.
        dataset_name : string
            The name of the subdirectory the JSON files will be stored in
            (and the name that will appear in startup.json and in the dropdown
            menu at startup). e.g. `yt` -> json files will appear in
            `Firefly/data/yt`.

        Returns
        -------
        reader : firefly_api.reader.Reader object
            A reader object from the firefly_api, configured
            to write its JSON files into
            ``<path_to_firefly>/data/<dataset_name>``.

        Examples
        --------

        >>> ramses_ds = yt.load(
        ...     "/Users/agurvich/Desktop/yt_workshop/"
        ...     + "DICEGalaxyDisk_nonCosmological/output_00002/info_00002.txt"
        ... )

        >>> region = ramses_ds.sphere(ramses_ds.domain_center, (1000, "kpc"))

        >>> reader = region.create_firefly_object(
        ...     path_to_firefly="/Users/agurvich/research/repos/Firefly",
        ...     fields_to_include=[
        ...         "particle_extra_field_1",
        ...         "particle_extra_field_2",
        ...     ],
        ...     fields_units=["dimensionless", "dimensionless"],
        ...     dataset_name="IsoGalaxyRamses",
) 783 784 >>> reader.options["color"]["io"] = [1, 1, 0, 1] 785 >>> reader.particleGroups[0].decimation_factor = 100 786 >>> reader.dumpToJSON() 787 """ 788 789 ## attempt to import firefly_api 790 try: 791 from firefly_api.particlegroup import ParticleGroup 792 from firefly_api.reader import Reader 793 except ImportError as e: 794 raise ImportError( 795 "Can't find firefly_api, ensure it " 796 "is in your python path or install it with " 797 "`python -m pip install firefly_api`. It is also available " 798 "on github at github.com/agurvich/firefly_api" 799 ) from e 800 801 ## handle default arguments 802 fields_to_include = [] if fields_to_include is None else fields_to_include 803 fields_units = [] if fields_units is None else fields_units 804 805 ## handle input validation, if any 806 if len(fields_units) != len(fields_to_include): 807 raise RuntimeError("Each requested field must have units.") 808 809 ## for safety, in case someone passes a float just cast it 810 default_decimation_factor = int(default_decimation_factor) 811 812 ## initialize a firefly reader instance 813 reader = Reader( 814 JSONdir=os.path.join(path_to_firefly, "data", dataset_name), 815 prefix="ytData", 816 clean_JSONdir=True, 817 ) 818 819 ## create a ParticleGroup object that contains *every* field 820 for ptype in sorted(self.ds.particle_types_raw): 821 ## skip this particle type if it has no particles in this dataset 822 if self[ptype, "relative_particle_position"].shape[0] == 0: 823 continue 824 825 ## loop through the fields and print them to the screen 826 if show_unused_fields: 827 ## read the available extra fields from yt 828 this_ptype_fields = self.ds.particle_fields_by_type[ptype] 829 830 ## load the extra fields and print them 831 for field in this_ptype_fields: 832 if field not in fields_to_include: 833 mylog.warning( 834 "detected (but did not request) %s %s", ptype, field 835 ) 836 837 ## you must have velocities (and they must be named "Velocities") 838 tracked_arrays = [ 839 
self[ptype, "relative_particle_velocity"].in_units(velocity_units) 840 ] 841 tracked_names = ["Velocities"] 842 843 ## explicitly go after the fields we want 844 for field, units in zip(fields_to_include, fields_units): 845 ## determine if you want to take the log of the field for Firefly 846 log_flag = "log(" in units 847 848 ## read the field array from the dataset 849 this_field_array = self[ptype, field] 850 851 ## fix the units string and prepend 'log' to the field for 852 ## the UI name 853 if log_flag: 854 units = units[len("log(") : -1] 855 field = f"log{field}" 856 857 ## perform the unit conversion and take the log if 858 ## necessary. 859 this_field_array.in_units(units) 860 if log_flag: 861 this_field_array = np.log10(this_field_array) 862 863 ## add this array to the tracked arrays 864 tracked_arrays += [this_field_array] 865 tracked_names = np.append(tracked_names, [field], axis=0) 866 867 ## flag whether we want to filter and/or color by these fields 868 ## we'll assume yes for both cases, this can be changed after 869 ## the reader object is returned to the user. 870 tracked_filter_flags = np.ones(len(tracked_names)) 871 tracked_colormap_flags = np.ones(len(tracked_names)) 872 873 ## create a firefly ParticleGroup for this particle type 874 pg = ParticleGroup( 875 UIname=ptype, 876 coordinates=self[ptype, "relative_particle_position"].in_units( 877 coordinate_units 878 ), 879 tracked_arrays=tracked_arrays, 880 tracked_names=tracked_names, 881 tracked_filter_flags=tracked_filter_flags, 882 tracked_colormap_flags=tracked_colormap_flags, 883 decimation_factor=default_decimation_factor, 884 ) 885 886 ## bind this particle group to the firefly reader object 887 reader.addParticleGroup(pg) 888 889 return reader 890 891 # Numpy-like Operations 892 def argmax(self, field, axis=None): 893 r"""Return the values at which the field is maximized. 

        This will, in a parallel-aware fashion, find the maximum value and then
        return to you the values at that maximum location that are requested
        for "axis". By default it will return the spatial positions (in the
        natural coordinate system), but it can be any field.

        Parameters
        ----------
        field : string or tuple field name
            The field to maximize.
        axis : string or list of strings, optional
            If supplied, the fields to sample along; if not supplied, defaults
            to the coordinate fields. This can be the name of the coordinate
            fields (i.e., 'x', 'y', 'z') or a list of fields, but cannot be 0,
            1, 2.

        Returns
        -------
        A list of YTQuantities as specified by the axis argument. A single
        value is returned when exactly one field is sampled.

        Examples
        --------

        >>> temp_at_max_rho = reg.argmax(
        ...     ("gas", "density"), axis=("gas", "temperature")
        ... )
        >>> max_rho_xyz = reg.argmax(("gas", "density"))
        >>> t_mrho, v_mrho = reg.argmax(
        ...     ("gas", "density"),
        ...     axis=[("gas", "temperature"), ("gas", "velocity_magnitude")],
        ... )
        >>> x, y, z = reg.argmax(("gas", "density"))

        """
        if axis is None:
            # Drop the first element of max_location (presumably the maximum
            # value itself) and return only the three positions.
            mv, pos0, pos1, pos2 = self.quantities.max_location(field)
            return pos0, pos1, pos2
        if isinstance(axis, str):
            axis = [axis]
        rv = self.quantities.sample_at_max_field_values(field, axis)
        # rv[0] is skipped (presumably the maximum itself). Unwrap the result
        # when exactly one sampled value follows it.
        if len(rv) == 2:
            return rv[1]
        return rv[1:]

    def argmin(self, field, axis=None):
        r"""Return the values at which the field is minimized.

        This will, in a parallel-aware fashion, find the minimum value and then
        return to you the values at that minimum location that are requested
        for "axis". By default it will return the spatial positions (in the
        natural coordinate system), but it can be any field.

        Parameters
        ----------
        field : string or tuple field name
            The field to minimize.
        axis : string or list of strings, optional
            If supplied, the fields to sample along; if not supplied, defaults
            to the coordinate fields. This can be the name of the coordinate
            fields (i.e., 'x', 'y', 'z') or a list of fields, but cannot be 0,
            1, 2.

        Returns
        -------
        A list of YTQuantities as specified by the axis argument. A single
        value is returned when exactly one field is sampled.

        Examples
        --------

        >>> temp_at_min_rho = reg.argmin(
        ...     ("gas", "density"), axis=("gas", "temperature")
        ... )
        >>> min_rho_xyz = reg.argmin(("gas", "density"))
        >>> t_mrho, v_mrho = reg.argmin(
        ...     ("gas", "density"),
        ...     axis=[("gas", "temperature"), ("gas", "velocity_magnitude")],
        ... )
        >>> x, y, z = reg.argmin(("gas", "density"))

        """
        if axis is None:
            # Drop the first element of min_location (presumably the minimum
            # value itself) and return only the three positions.
            mv, pos0, pos1, pos2 = self.quantities.min_location(field)
            return pos0, pos1, pos2
        if isinstance(axis, str):
            axis = [axis]
        rv = self.quantities.sample_at_min_field_values(field, axis)
        # rv[0] is skipped (presumably the minimum itself). Unwrap the result
        # when exactly one sampled value follows it.
        if len(rv) == 2:
            return rv[1]
        return rv[1:]

    def _compute_extrema(self, field):
        # Return the (min, max) of ``field``, memoised per instance in
        # self._extrema_cache. Keys are the ``field`` argument exactly as
        # passed, so differently spelled names are cached separately.
        if self._extrema_cache is None:
            self._extrema_cache = {}
        if field not in self._extrema_cache:
            # Note we still need to call extrema for each field, as of right
            # now
            mi, ma = self.quantities.extrema(field)
            self._extrema_cache[field] = (mi, ma)
        return self._extrema_cache[field]

    # Created lazily on the instance by _compute_extrema.
    _extrema_cache = None

    def max(self, field, axis=None):
        r"""Compute the maximum of a field, optionally along an axis.

        This will, in a parallel-aware fashion, compute the maximum of the
        given field. Supplying an axis will result in a return value of a
        YTProjection, with method 'mip' for maximum intensity. If the max has
        already been requested, it will use the cached extrema value.

        Parameters
        ----------
        field : string or tuple field name
            The field to maximize.
1008 axis : string, optional 1009 If supplied, the axis to project the maximum along. 1010 1011 Returns 1012 ------- 1013 Either a scalar or a YTProjection. 1014 1015 Examples 1016 -------- 1017 1018 >>> max_temp = reg.max(("gas", "temperature")) 1019 >>> max_temp_proj = reg.max(("gas", "temperature"), axis=("index", "x")) 1020 """ 1021 if axis is None: 1022 rv = tuple(self._compute_extrema(f)[1] for f in iter_fields(field)) 1023 if len(rv) == 1: 1024 return rv[0] 1025 return rv 1026 elif axis in self.ds.coordinates.axis_name: 1027 r = self.ds.proj(field, axis, data_source=self, method="mip") 1028 return r 1029 else: 1030 raise NotImplementedError(f"Unknown axis {axis}") 1031 1032 def min(self, field, axis=None): 1033 r"""Compute the minimum of a field. 1034 1035 This will, in a parallel-aware fashion, compute the minimum of the 1036 given field. Supplying an axis is not currently supported. If the max 1037 has already been requested, it will use the cached extrema value. 1038 1039 Parameters 1040 ---------- 1041 field : string or tuple field name 1042 The field to minimize. 1043 axis : string, optional 1044 If supplied, the axis to compute the minimum along. 1045 1046 Returns 1047 ------- 1048 Scalar. 1049 1050 Examples 1051 -------- 1052 1053 >>> min_temp = reg.min(("gas", "temperature")) 1054 """ 1055 if axis is None: 1056 rv = tuple(self._compute_extrema(f)[0] for f in iter_fields(field)) 1057 if len(rv) == 1: 1058 return rv[0] 1059 return rv 1060 elif axis in self.ds.coordinates.axis_name: 1061 raise NotImplementedError("Minimum intensity projection not implemented.") 1062 else: 1063 raise NotImplementedError(f"Unknown axis {axis}") 1064 1065 def std(self, field, weight=None): 1066 """Compute the standard deviation of a field. 1067 1068 This will, in a parallel-ware fashion, compute the standard 1069 deviation of the given field. 
1070 1071 Parameters 1072 ---------- 1073 field : string or tuple field name 1074 The field to calculate the standard deviation of 1075 weight : string or tuple field name 1076 The field to weight the standard deviation calculation 1077 by. Defaults to unweighted if unset. 1078 1079 Returns 1080 ------- 1081 Scalar 1082 """ 1083 weight_field = sanitize_weight_field(self.ds, field, weight) 1084 return self.quantities.weighted_standard_deviation(field, weight_field)[0] 1085 1086 def ptp(self, field): 1087 r"""Compute the range of values (maximum - minimum) of a field. 1088 1089 This will, in a parallel-aware fashion, compute the "peak-to-peak" of 1090 the given field. 1091 1092 Parameters 1093 ---------- 1094 field : string or tuple field name 1095 The field to average. 1096 1097 Returns 1098 ------- 1099 Scalar 1100 1101 Examples 1102 -------- 1103 1104 >>> rho_range = reg.ptp(("gas", "density")) 1105 """ 1106 ex = self._compute_extrema(field) 1107 return ex[1] - ex[0] 1108 1109 def profile( 1110 self, 1111 bin_fields, 1112 fields, 1113 n_bins=64, 1114 extrema=None, 1115 logs=None, 1116 units=None, 1117 weight_field=("gas", "mass"), 1118 accumulation=False, 1119 fractional=False, 1120 deposition="ngp", 1121 ): 1122 r""" 1123 Create a 1, 2, or 3D profile object from this data_source. 1124 1125 The dimensionality of the profile object is chosen by the number of 1126 fields given in the bin_fields argument. This simply calls 1127 :func:`yt.data_objects.profiles.create_profile`. 1128 1129 Parameters 1130 ---------- 1131 bin_fields : list of strings 1132 List of the binning fields for profiling. 1133 fields : list of strings 1134 The fields to be profiled. 1135 n_bins : int or list of ints 1136 The number of bins in each dimension. If None, 64 bins for 1137 each bin are used for each bin field. 1138 Default: 64. 1139 extrema : dict of min, max tuples 1140 Minimum and maximum values of the bin_fields for the profiles. 1141 The keys correspond to the field names. 
Defaults to the extrema 1142 of the bin_fields of the dataset. If a units dict is provided, extrema 1143 are understood to be in the units specified in the dictionary. 1144 logs : dict of boolean values 1145 Whether or not to log the bin_fields for the profiles. 1146 The keys correspond to the field names. Defaults to the take_log 1147 attribute of the field. 1148 units : dict of strings 1149 The units of the fields in the profiles, including the bin_fields. 1150 weight_field : str or tuple field identifier 1151 The weight field for computing weighted average for the profile 1152 values. If None, the profile values are sums of the data in 1153 each bin. 1154 accumulation : bool or list of bools 1155 If True, the profile values for a bin n are the cumulative sum of 1156 all the values from bin 0 to n. If -True, the sum is reversed so 1157 that the value for bin n is the cumulative sum from bin N (total bins) 1158 to n. If the profile is 2D or 3D, a list of values can be given to 1159 control the summation in each dimension independently. 1160 Default: False. 1161 fractional : If True the profile values are divided by the sum of all 1162 the profile data such that the profile represents a probability 1163 distribution function. 1164 deposition : Controls the type of deposition used for ParticlePhasePlots. 1165 Valid choices are 'ngp' and 'cic'. Default is 'ngp'. This parameter is 1166 ignored the if the input fields are not of particle type. 1167 1168 1169 Examples 1170 -------- 1171 1172 Create a 1d profile. Access bin field from profile.x and field 1173 data from profile[<field_name>]. 1174 1175 >>> ds = load("DD0046/DD0046") 1176 >>> ad = ds.all_data() 1177 >>> profile = ad.profile( 1178 ... ad, 1179 ... [("gas", "density")], 1180 ... [("gas", "temperature"), ("gas", "velocity_x")], 1181 ... 
) 1182 >>> print(profile.x) 1183 >>> print(profile["gas", "temperature"]) 1184 >>> plot = profile.plot() 1185 """ 1186 p = create_profile( 1187 self, 1188 bin_fields, 1189 fields, 1190 n_bins, 1191 extrema, 1192 logs, 1193 units, 1194 weight_field, 1195 accumulation, 1196 fractional, 1197 deposition, 1198 ) 1199 return p 1200 1201 def mean(self, field, axis=None, weight=None): 1202 r"""Compute the mean of a field, optionally along an axis, with a 1203 weight. 1204 1205 This will, in a parallel-aware fashion, compute the mean of the 1206 given field. If an axis is supplied, it will return a projection, 1207 where the weight is also supplied. By default the weight field will be 1208 "ones" or "particle_ones", depending on the field being averaged, 1209 resulting in an unweighted average. 1210 1211 Parameters 1212 ---------- 1213 field : string or tuple field name 1214 The field to average. 1215 axis : string, optional 1216 If supplied, the axis to compute the mean along (i.e., to project 1217 along) 1218 weight : string, optional 1219 The field to use as a weight. 1220 1221 Returns 1222 ------- 1223 Scalar or YTProjection. 1224 1225 Examples 1226 -------- 1227 1228 >>> avg_rho = reg.mean(("gas", "density"), weight="cell_volume") 1229 >>> rho_weighted_T = reg.mean( 1230 ... ("gas", "temperature"), axis=("index", "y"), weight=("gas", "density") 1231 ... ) 1232 """ 1233 weight_field = sanitize_weight_field(self.ds, field, weight) 1234 if axis in self.ds.coordinates.axis_name: 1235 r = self.ds.proj(field, axis, data_source=self, weight_field=weight_field) 1236 elif axis is None: 1237 r = self.quantities.weighted_average_quantity(field, weight_field) 1238 else: 1239 raise NotImplementedError(f"Unknown axis {axis}") 1240 return r 1241 1242 def sum(self, field, axis=None): 1243 r"""Compute the sum of a field, optionally along an axis. 1244 1245 This will, in a parallel-aware fashion, compute the sum of the given 1246 field. 
If an axis is specified, it will return a projection (using 1247 method type "sum", which does not take into account path length) along 1248 that axis. 1249 1250 Parameters 1251 ---------- 1252 field : string or tuple field name 1253 The field to sum. 1254 axis : string, optional 1255 If supplied, the axis to sum along. 1256 1257 Returns 1258 ------- 1259 Either a scalar or a YTProjection. 1260 1261 Examples 1262 -------- 1263 1264 >>> total_vol = reg.sum("cell_volume") 1265 >>> cell_count = reg.sum(("index", "ones"), axis=("index", "x")) 1266 """ 1267 # Because we're using ``sum`` to specifically mean a sum or a 1268 # projection with the method="sum", we do not utilize the ``mean`` 1269 # function. 1270 if axis in self.ds.coordinates.axis_name: 1271 with self._field_parameter_state({"axis": axis}): 1272 r = self.ds.proj(field, axis, data_source=self, method="sum") 1273 elif axis is None: 1274 r = self.quantities.total_quantity(field) 1275 else: 1276 raise NotImplementedError(f"Unknown axis {axis}") 1277 return r 1278 1279 def integrate(self, field, weight=None, axis=None): 1280 r"""Compute the integral (projection) of a field along an axis. 1281 1282 This projects a field along an axis. 1283 1284 Parameters 1285 ---------- 1286 field : string or tuple field name 1287 The field to project. 1288 weight : string or tuple field name 1289 The field to weight the projection by 1290 axis : string 1291 The axis to project along. 
1292 1293 Returns 1294 ------- 1295 YTProjection 1296 1297 Examples 1298 -------- 1299 1300 >>> column_density = reg.integrate(("gas", "density"), axis=("index", "z")) 1301 """ 1302 if weight is not None: 1303 weight_field = sanitize_weight_field(self.ds, field, weight) 1304 else: 1305 weight_field = None 1306 if axis in self.ds.coordinates.axis_name: 1307 r = self.ds.proj(field, axis, data_source=self, weight_field=weight_field) 1308 else: 1309 raise NotImplementedError(f"Unknown axis {axis}") 1310 return r 1311 1312 @property 1313 def _hash(self): 1314 s = f"{self}" 1315 try: 1316 import hashlib 1317 1318 return hashlib.md5(s.encode("utf-8")).hexdigest() 1319 except ImportError: 1320 return s 1321 1322 def __reduce__(self): 1323 args = tuple( 1324 [self.ds._hash(), self._type_name] 1325 + [getattr(self, n) for n in self._con_args] 1326 + [self.field_parameters] 1327 ) 1328 return (_reconstruct_object, args) 1329 1330 def clone(self): 1331 r"""Clone a data object. 1332 1333 This will make a duplicate of a data object; note that the 1334 `field_parameters` may not necessarily be deeply-copied. If you modify 1335 the field parameters in-place, it may or may not be shared between the 1336 objects, depending on the type of object that that particular field 1337 parameter is. 1338 1339 Notes 1340 ----- 1341 One use case for this is to have multiple identical data objects that 1342 are being chunked over in different orders. 
1343 1344 Examples 1345 -------- 1346 1347 >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") 1348 >>> sp = ds.sphere("c", 0.1) 1349 >>> sp_clone = sp.clone() 1350 >>> sp[("gas", "density")] 1351 >>> print(sp.field_data.keys()) 1352 [("gas", "density")] 1353 >>> print(sp_clone.field_data.keys()) 1354 [] 1355 """ 1356 args = self.__reduce__() 1357 return args[0](self.ds, *args[1][1:]) 1358 1359 def __repr__(self): 1360 # We'll do this the slow way to be clear what's going on 1361 s = f"{self.__class__.__name__} ({self.ds}): " 1362 for i in self._con_args: 1363 try: 1364 s += ", {}={}".format( 1365 i, 1366 getattr(self, i).in_base(unit_system=self.ds.unit_system), 1367 ) 1368 except AttributeError: 1369 s += f", {i}={getattr(self, i)}" 1370 return s 1371 1372 @contextmanager 1373 def _field_parameter_state(self, field_parameters): 1374 # What we're doing here is making a copy of the incoming field 1375 # parameters, and then updating it with our own. This means that we'll 1376 # be using our own center, if set, rather than the supplied one. But 1377 # it also means that any additionally set values can override it. 
1378 old_field_parameters = self.field_parameters 1379 new_field_parameters = field_parameters.copy() 1380 new_field_parameters.update(old_field_parameters) 1381 self.field_parameters = new_field_parameters 1382 yield 1383 self.field_parameters = old_field_parameters 1384 1385 @contextmanager 1386 def _field_type_state(self, ftype, finfo, obj=None): 1387 if obj is None: 1388 obj = self 1389 old_particle_type = obj._current_particle_type 1390 old_fluid_type = obj._current_fluid_type 1391 fluid_types = self.ds.fluid_types 1392 if finfo.sampling_type == "particle" and ftype not in fluid_types: 1393 obj._current_particle_type = ftype 1394 else: 1395 obj._current_fluid_type = ftype 1396 yield 1397 obj._current_particle_type = old_particle_type 1398 obj._current_fluid_type = old_fluid_type 1399 1400 def _tupleize_field(self, field): 1401 1402 try: 1403 ftype, fname = field.name 1404 return ftype, fname 1405 except AttributeError: 1406 pass 1407 1408 if is_sequence(field) and not isinstance(field, str): 1409 try: 1410 ftype, fname = field 1411 if not all(isinstance(_, str) for _ in field): 1412 raise TypeError 1413 return ftype, fname 1414 except TypeError as e: 1415 raise YTFieldNotParseable(field) from e 1416 except ValueError: 1417 pass 1418 1419 try: 1420 fname = field 1421 finfo = self.ds._get_field_info(field) 1422 if finfo.sampling_type == "particle": 1423 ftype = self._current_particle_type 1424 if hasattr(self.ds, "_sph_ptypes"): 1425 ptypes = self.ds._sph_ptypes 1426 if finfo.name[0] in ptypes: 1427 ftype = finfo.name[0] 1428 elif finfo.alias_field and finfo.alias_name[0] in ptypes: 1429 ftype = self._current_fluid_type 1430 else: 1431 ftype = self._current_fluid_type 1432 if (ftype, fname) not in self.ds.field_info: 1433 ftype = self.ds._last_freq[0] 1434 return ftype, fname 1435 except YTFieldNotFound: 1436 pass 1437 1438 if isinstance(field, str): 1439 return "unknown", field 1440 1441 raise YTFieldNotParseable(field) 1442 1443 def _determine_fields(self, 
fields): 1444 explicit_fields = [] 1445 for field in iter_fields(fields): 1446 if field in self._container_fields: 1447 explicit_fields.append(field) 1448 continue 1449 1450 ftype, fname = self._tupleize_field(field) 1451 # print(field, " : ",ftype, fname) 1452 finfo = self.ds._get_field_info(ftype, fname) 1453 1454 # really ugly check to ensure that this field really does exist somewhere, 1455 # in some naming convention, before returning it as a possible field type 1456 if ( 1457 (ftype, fname) not in self.ds.field_info 1458 and (ftype, fname) not in self.ds.field_list 1459 and fname not in self.ds.field_list 1460 and (ftype, fname) not in self.ds.derived_field_list 1461 and fname not in self.ds.derived_field_list 1462 and (ftype, fname) not in self._container_fields 1463 ): 1464 raise YTFieldNotFound((ftype, fname), self.ds) 1465 1466 # these tests are really insufficient as a field type may be valid, and the 1467 # field name may be valid, but not the combination (field type, field name) 1468 particle_field = finfo.sampling_type == "particle" 1469 local_field = finfo.local_sampling 1470 if local_field: 1471 pass 1472 elif particle_field and ftype not in self.ds.particle_types: 1473 raise YTFieldTypeNotFound(ftype, ds=self.ds) 1474 elif not particle_field and ftype not in self.ds.fluid_types: 1475 raise YTFieldTypeNotFound(ftype, ds=self.ds) 1476 explicit_fields.append((ftype, fname)) 1477 return explicit_fields 1478 1479 _tree = None 1480 1481 @property 1482 def tiles(self): 1483 if self._tree is not None: 1484 return self._tree 1485 self._tree = AMRKDTree(self.ds, data_source=self) 1486 return self._tree 1487 1488 @property 1489 def blocks(self): 1490 for _io_chunk in self.chunks([], "io"): 1491 for _chunk in self.chunks([], "spatial", ngz=0): 1492 # For grids this will be a grid object, and for octrees it will 1493 # be an OctreeSubset. Note that we delegate to the sub-object. 
1494 o = self._current_chunk.objs[0] 1495 cache_fp = o.field_parameters.copy() 1496 o.field_parameters.update(self.field_parameters) 1497 for b, m in o.select_blocks(self.selector): 1498 if m is None: 1499 continue 1500 yield b, m 1501 o.field_parameters = cache_fp 1502 1503 1504# PR3124: Given that save_as_dataset is now the recommended method for saving 1505# objects (see Issue 2021 and references therein), the following has been re-written. 1506# 1507# Original comments (still true): 1508# 1509# In the future, this would be better off being set up to more directly 1510# reference objects or retain state, perhaps with a context manager. 1511# 1512# One final detail: time series or multiple datasets in a single pickle 1513# seems problematic. 1514 1515 1516def _get_ds_by_hash(hash): 1517 from yt.data_objects.static_output import Dataset 1518 1519 if isinstance(hash, Dataset): 1520 return hash 1521 from yt.data_objects.static_output import _cached_datasets 1522 1523 for ds in _cached_datasets.values(): 1524 if ds._hash() == hash: 1525 return ds 1526 return None 1527 1528 1529def _reconstruct_object(*args, **kwargs): 1530 # returns a reconstructed YTDataContainer. As of PR 3124, we now return 1531 # the actual YTDataContainer rather than a (ds, YTDataContainer) tuple. 
1532 1533 # pull out some arguments 1534 dsid = args[0] # the hash id 1535 dtype = args[1] # DataContainer type (e.g., 'region') 1536 field_parameters = args[-1] # the field parameters 1537 1538 # re-instantiate the base dataset from the hash and ParameterFileStore 1539 ds = _get_ds_by_hash(dsid) 1540 override_weakref = False 1541 if not ds: 1542 override_weakref = True 1543 datasets = ParameterFileStore() 1544 ds = datasets.get_ds_hash(dsid) 1545 1546 # instantiate the class with remainder of the args and adjust the state 1547 cls = getattr(ds, dtype) 1548 obj = cls(*args[2:-1]) 1549 obj.field_parameters.update(field_parameters) 1550 1551 # any nested ds references are weakref.proxy(ds), so need to ensure the ds 1552 # we just loaded persists when we leave this function (nosetests fail without 1553 # this) if we did not have an actual dataset as an argument. 1554 if hasattr(obj, "ds") and override_weakref: 1555 obj.ds = ds 1556 1557 return obj 1558