import os
import time
import uuid
import weakref
from itertools import chain, product, repeat
from numbers import Number as numeric_type

import numpy as np
from more_itertools import always_iterable

from yt.data_objects.field_data import YTFieldData
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.index_subobjects.octree_subset import OctreeSubset
from yt.data_objects.index_subobjects.unstructured_mesh import (
    SemiStructuredMesh,
    UnstructuredMesh,
)
from yt.data_objects.particle_unions import ParticleUnion
from yt.data_objects.static_output import Dataset, ParticleFile
from yt.data_objects.unions import MeshUnion
from yt.frontends.sph.data_structures import SPHParticleIndex
from yt.geometry.geometry_handler import YTDataChunk
from yt.geometry.grid_geometry_handler import GridIndex
from yt.geometry.oct_container import OctreeContainer
from yt.geometry.oct_geometry_handler import OctreeIndex
from yt.geometry.unstructured_mesh_handler import UnstructuredIndex
from yt.units import YTQuantity
from yt.utilities.io_handler import io_registry
from yt.utilities.lib.cykdtree import PyKDTree
from yt.utilities.lib.misc_utilities import get_box_grids_level
from yt.utilities.lib.particle_kdtree_tools import (
    estimate_density,
    generate_smoothing_length,
)
from yt.utilities.logger import ytLogger as mylog

from .definitions import process_data, set_particle_types
from .fields import StreamFieldInfo


class StreamGrid(AMRGridPatch):
    """
    Class representing a single In-memory Grid instance.
    """

    __slots__ = ["proc_num"]
    _id_offset = 0

    def __init__(self, id, index):
        """
        Returns an instance of StreamGrid with *id*, associated with *index*.

        In-memory grids have no backing file, so the grid is created with
        ``filename=None``. The level and parent/child links start out unset
        (-1 / empty) and are filled in by the owning index.
        """
        # All of the field parameters will be passed to us as needed.
        AMRGridPatch.__init__(self, id, filename=None, index=index)
        self._children_ids = []
        self._parent_id = -1
        self.Level = -1

    def set_filename(self, filename):
        # In-memory grids have no file on disk; deliberately a no-op.
        pass

    def __repr__(self):
        return "StreamGrid_%04i" % (self.id)

    @property
    def Parent(self):
        """The parent grid, or None when this grid has no parent."""
        if self._parent_id == -1:
            return None
        return self.index.grids[self._parent_id - self._id_offset]

    @property
    def Children(self):
        """List of child grids, looked up in the owning index's grid list."""
        return [self.index.grids[cid - self._id_offset] for cid in self._children_ids]


class StreamHandler:
    """
    Container for the in-memory description of a stream dataset.

    Holds per-grid geometry (edges, dimensions, levels, parent ids), particle
    counts, processor ids, the field data and field units, the code units,
    an optional I/O handler, the particle-type mapping and the periodicity.
    Other attributes read elsewhere in this module (``name``,
    ``domain_left_edge``, ``refine_by``, ...) are not set here; presumably
    the loader that builds the handler attaches them.
    """

    def __init__(
        self,
        left_edges,
        right_edges,
        dimensions,
        levels,
        parent_ids,
        particle_count,
        processor_ids,
        fields,
        field_units,
        code_units,
        io=None,
        particle_types=None,
        periodicity=(True, True, True),
    ):
        if particle_types is None:
            particle_types = {}
        self.left_edges = np.array(left_edges)
        self.right_edges = np.array(right_edges)
        self.dimensions = dimensions
        self.levels = levels
        self.parent_ids = parent_ids
        self.particle_count = particle_count
        self.processor_ids = processor_ids
        # ``levels`` must be an array: its size defines the grid count.
        self.num_grids = self.levels.size
        self.fields = fields
        self.field_units = field_units
        self.code_units = code_units
        self.io = io
        self.particle_types = particle_types
        self.periodicity = periodicity

    def get_fields(self):
        """Return every field name known to the stored field container."""
        return self.fields.all_fields

    def get_particle_type(self, field):
        """Return the particle-type flag recorded for *field*, or False."""
        return self.particle_types.get(field, False)


class StreamHierarchy(GridIndex):
    """Grid index whose geometry and data come from a StreamHandler."""

    grid = StreamGrid

    def __init__(self, ds, dataset_type=None):
        self.dataset_type = dataset_type
        self.float_type = "float64"
        self.dataset = weakref.proxy(ds)  # for _obtain_enzo
        self.stream_handler = ds.stream_handler
        self.directory = os.getcwd()
        GridIndex.__init__(self, ds, dataset_type)

    def _count_grids(self):
        self.num_grids = self.stream_handler.num_grids

    def _parse_index(self):
        """Copy grid metadata from the handler and build the grid objects."""
        handler = self.stream_handler
        self.grid_dimensions = handler.dimensions
        self.grid_left_edge[:] = handler.left_edges
        self.grid_right_edge[:] = handler.right_edges
        self.grid_levels[:] = handler.levels
        self.min_level = self.grid_levels.min()
        self.grid_procs = handler.processor_ids
        self.grid_particle_count[:] = handler.particle_count
        mylog.debug("Copying reverse tree")
        self.grids = []
        # Grid ids are 0-indexed and StreamGrid._id_offset is 0, so parent
        # ids index directly into self.grids.
        for gid in range(self.num_grids):
            grid = self.grid(gid, self)
            grid.Level = self.grid_levels[gid, 0]
            self.grids.append(grid)
        parent_ids = handler.parent_ids
        if parent_ids is not None:
            reverse_tree = parent_ids.tolist()
            # Initial setup: a negative parent id marks a root grid.
            for gid, pid in enumerate(reverse_tree):
                if pid >= 0:
                    self.grids[gid]._parent_id = pid
                    self.grids[pid]._children_ids.append(self.grids[gid].id)
        else:
            mylog.debug("Reconstructing parent-child relationships")
            self._reconstruct_parent_child()
        self.max_level = self.grid_levels.max()
        mylog.debug("Preparing grids")
        temp_grids = np.empty(self.num_grids, dtype="object")
        for i, grid in enumerate(self.grids):
            if i % 10000 == 0:
                mylog.debug("Prepared % 7i / % 7i grids", i, self.num_grids)
            grid.filename = None
            grid._prepare_grid()
            grid._setup_dx()
            grid.proc_num = self.grid_procs[i]
            temp_grids[i] = grid
        self.grids = temp_grids
        mylog.debug("Prepared")

    def _reconstruct_parent_child(self):
        """
        Rebuild parent/child links from grid geometry when no parent ids were
        supplied, and store the recovered parent ids on the stream handler.
        """
        mask = np.empty(len(self.grids), dtype="int32")
        mylog.debug("First pass; identifying child grids")
        for i, grid in enumerate(self.grids):
            # Pass an explicit scalar level: grid_levels has shape
            # (num_grids, 1), and converting a 1-element array to a scalar is
            # deprecated in NumPy.
            get_box_grids_level(
                self.grid_left_edge[i, :],
                self.grid_right_edge[i, :],
                int(self.grid_levels[i, 0]) + 1,
                self.grid_left_edge,
                self.grid_right_edge,
                self.grid_levels,
                mask,
            )
            # Grids flagged non-zero in mask (presumably the next-level grids
            # overlapping grid i) become its children.
            grid._children_ids = np.flatnonzero(mask)
        mylog.debug("Second pass; identifying parents")
        self.stream_handler.parent_ids = np.full(
            self.stream_handler.num_grids, -1, dtype="int64"
        )
        for i, grid in enumerate(self.grids):  # Second pass
            for child in grid.Children:
                child._parent_id = i
                # _id_offset = 0
                self.stream_handler.parent_ids[child.id] = i

    def _initialize_grid_arrays(self):
        GridIndex._initialize_grid_arrays(self)
        self.grid_procs = np.zeros((self.num_grids, 1), "int32")

    def _detect_output_fields(self):
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        fl = set(self.stream_handler.get_fields())
        fl.update(set(getattr(self, "field_list", [])))
        self.field_list = list(fl)

    def _populate_grid_objects(self):
        for g in self.grids:
            g._setup_dx()
        self.max_level = self.grid_levels.max()

    def _setup_data_io(self):
        # Prefer a user-supplied I/O handler; otherwise use the registered one.
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _reset_particle_count(self):
        """Re-sync per-grid particle counts from the stream handler."""
        self.grid_particle_count[:] = self.stream_handler.particle_count
        for i, grid in enumerate(self.grids):
            grid.NumberOfParticles = self.grid_particle_count[i, 0]

    def update_data(self, data):
        """
        Update the stream data with a new data dict. If fields already exist,
        they will be replaced, but if they do not, they will be added. Fields
        already in the stream but not part of the data dict will be left
        alone.

        Parameters
        ----------
        data : sequence of dict
            One data dict per grid, indexed in the same order as
            ``self.grids``. Particle types are detected from ``data[0]``.
        """
        particle_types = set_particle_types(data[0])

        self.stream_handler.particle_types.update(particle_types)
        self.ds._find_particle_types()

        for i, grid in enumerate(self.grids):
            field_units, gdata, number_of_particles = process_data(data[i])
            self.stream_handler.particle_count[i] = number_of_particles
            self.stream_handler.field_units.update(field_units)
            for field in gdata:
                # Invalidate any value already cached on the grid.
                grid.field_data.pop(field, None)
                self.stream_handler.fields[grid.id][field] = gdata[field]

        self._reset_particle_count()
        # We only want to create a superset of fields here. Filter in place
        # (same list object): calling remove() while iterating over the list
        # would skip the entry following each removed one.
        field_list = self.ds.field_list
        field_list[:] = [field for field in field_list if field[0] != "all"]
        self._detect_output_fields()
        self.ds.create_field_info()
        mylog.debug("Creating Particle Union 'all'")
        pu = ParticleUnion("all", list(self.ds.particle_types_raw))
        self.ds.add_particle_union(pu)
        self.ds.particle_types = tuple(set(self.ds.particle_types))


class StreamDataset(Dataset):
    """Dataset backed by in-memory data described by a StreamHandler."""

    _index_class = StreamHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
    ):
        self.fluid_types += ("stream",)
        self.geometry = geometry
        self.stream_handler = stream_handler
        self._find_particle_types()
        # Register under a unique name so the dataset can be found again in
        # yt's dataset cache.
        name = f"InMemoryParameterFile_{uuid.uuid4().hex}"
        from yt.data_objects.static_output import _cached_datasets

        _cached_datasets[name] = self
        Dataset.__init__(
            self,
            name,
            self._dataset_type,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )

    def _parse_parameter_file(self):
        self.basename = self.stream_handler.name
        # The creation timestamp doubles as the unique identifier.
        self.parameters["CurrentTimeIdentifier"] = time.time()
        self.unique_identifier = self.parameters["CurrentTimeIdentifier"]
        self.domain_left_edge = self.stream_handler.domain_left_edge.copy()
        self.domain_right_edge = self.stream_handler.domain_right_edge.copy()
        self.refine_by = self.stream_handler.refine_by
        self.dimensionality = self.stream_handler.dimensionality
        self._periodicity = self.stream_handler.periodicity
        self.domain_dimensions = self.stream_handler.domain_dimensions
        self.current_time = self.stream_handler.simulation_time
        self.gamma = 5.0 / 3.0
        self.parameters["EOSType"] = -1
        self.parameters["CosmologyHubbleConstantNow"] = 1.0
        self.parameters["CosmologyCurrentRedshift"] = 1.0
        self.parameters["HydroMethod"] = -1
        if self.stream_handler.cosmology_simulation:
            self.cosmological_simulation = 1
            self.current_redshift = self.stream_handler.current_redshift
            self.omega_lambda = self.stream_handler.omega_lambda
            self.omega_matter = self.stream_handler.omega_matter
            self.hubble_constant = self.stream_handler.hubble_constant
        else:
            self.current_redshift = 0.0
            self.omega_lambda = 0.0
            self.omega_matter = 0.0
            self.hubble_constant = 0.0
            self.cosmological_simulation = 0

    def _set_units(self):
        self.field_units = self.stream_handler.field_units

    def _set_code_unit_attributes(self):
        """
        Set length/mass/time/velocity/magnetic code units from the handler.

        Each entry of ``stream_handler.code_units`` may be a unit string
        (value 1), a bare number (interpreted in the matching cgs unit), a
        YTQuantity, or a ``(value, unit)`` tuple.

        Raises
        ------
        RuntimeError
            If an entry is none of the accepted forms.
        """
        base_units = self.stream_handler.code_units
        attrs = (
            "length_unit",
            "mass_unit",
            "time_unit",
            "velocity_unit",
            "magnetic_unit",
        )
        cgs_units = ("cm", "g", "s", "cm/s", "gauss")
        for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
            if isinstance(unit, str):
                uq = self.quan(1.0, unit)
            elif isinstance(unit, numeric_type):
                uq = self.quan(unit, cgs_unit)
            elif isinstance(unit, YTQuantity):
                uq = unit
            elif isinstance(unit, tuple):
                uq = self.quan(unit[0], unit[1])
            else:
                raise RuntimeError(f"{attr} ({unit}) is invalid.")
            setattr(self, attr, uq)

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        # Stream datasets are never auto-detected from a file.
        return False

    @property
    def _skip_cache(self):
        return True

    def _find_particle_types(self):
        """Collect particle types from handler keys flagged as particle fields."""
        particle_types = {
            key[0] for key, is_particle in self.stream_handler.particle_types.items()
            if is_particle
        }
        self.particle_types = tuple(particle_types)
        self.particle_types_raw = self.particle_types


class StreamDictFieldHandler(dict):
    """Dict of per-grid (or per-file) field dicts, plus extra field names."""

    _additional_fields = ()

    @property
    def all_fields(self):
        """Unique field names across all stored dicts and _additional_fields."""
        names = set(self._additional_fields)
        names.update(chain.from_iterable(s.keys() for s in self.values()))
        return list(names)


class StreamParticleIndex(SPHParticleIndex):
    """Particle index whose data come from a StreamHandler."""

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        # Prefer a user-supplied I/O handler; otherwise use the registered one.
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def update_data(self, data):
        """
        Update the stream data with a new data dict. If fields already exist,
        they will be replaced, but if they do not, they will be added. Fields
        already in the stream but not part of the data dict will be left
        alone.

        Keys that are not tuples are assigned to the "io" particle type.
        """
        # Alias
        ds = self.ds
        handler = ds.stream_handler

        # Preprocess
        field_units, data, _ = process_data(data)
        pdata = {}
        for key in data.keys():
            if not isinstance(key, tuple):
                field = ("io", key)
                mylog.debug("Reassigning '%s' to '%s'", key, field)
            else:
                field = key
            pdata[field] = data[key]
        data = pdata  # Drop reference count
        particle_types = set_particle_types(data)

        # Update particle types
        handler.particle_types.update(particle_types)
        ds._find_particle_types()

        # Update fields
        handler.field_units.update(field_units)
        fields = handler.fields
        for field in data.keys():
            if field not in fields._additional_fields:
                fields._additional_fields += (field,)
        fields["stream_file"].update(data)

        # Update field list. Filter in place (same list object): calling
        # remove() while iterating over the list would skip entries.
        field_list = self.ds.field_list
        field_list[:] = [
            field for field in field_list if field[0] not in ("all", "nbody")
        ]
        self._detect_output_fields()
        self.ds.create_field_info()


class StreamParticleFile(ParticleFile):
    pass


class StreamParticlesDataset(StreamDataset):
    """Stream dataset holding particle data in a single "stream_file"."""

    _index_class = StreamParticleIndex
    _file_class = StreamParticleFile
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_particles"
    file_count = 1
    filename_template = "stream_file"
    _proj_type = "particle_proj"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
    ):
        super().__init__(
            stream_handler,
            storage_filename=storage_filename,
            geometry=geometry,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
        )
        stream_fields = stream_handler.fields["stream_file"]
        # This is the current method of detecting SPH data.
        # This should be made more flexible in the future.
        if ("io", "density") in stream_fields and (
            "io",
            "smoothing_length",
        ) in stream_fields:
            self._sph_ptypes = ("io",)

    def add_sph_fields(self, n_neighbors=32, kernel="cubic", sph_ptype="io"):
        """Add SPH fields for the specified particle type.

        For a particle type with "particle_position" and "particle_mass" already
        defined, this method adds the "smoothing_length" and "density" fields.
        "smoothing_length" is computed as the distance to the nth nearest
        neighbor. "density" is computed as the SPH (gather) smoothed mass. The
        SPH fields are added only if they don't already exist.

        Parameters
        ----------
        n_neighbors : int
            The number of neighbors to use in smoothing length computation.
        kernel : str
            The kernel function to use in density estimation.
        sph_ptype : str
            The SPH particle type. Each dataset has one sph_ptype only. This
            method will overwrite existing sph_ptype of the dataset.

        """
        mylog.info("Generating SPH fields")

        # Unify units
        l_unit = "code_length"
        m_unit = "code_mass"
        d_unit = "code_mass / code_length**3"

        # Read basic fields
        ad = self.all_data()
        pos = ad[sph_ptype, "particle_position"].to(l_unit).d
        mass = ad[sph_ptype, "particle_mass"].to(m_unit).d

        # Construct k-d tree
        kdtree = PyKDTree(
            pos.astype("float64"),
            left_edge=self.domain_left_edge.to_value(l_unit),
            right_edge=self.domain_right_edge.to_value(l_unit),
            periodic=self.periodicity,
            leafsize=2 * int(n_neighbors),
        )
        # kdtree.idx reorders particles into tree order; ``order`` maps the
        # tree-ordered results back to the original particle order.
        order = np.argsort(kdtree.idx)

        def exists(fname):
            if (sph_ptype, fname) in self.derived_field_list:
                mylog.info(
                    "Field ('%s','%s') already exists. Skipping", sph_ptype, fname
                )
                return True
            else:
                mylog.info("Generating field ('%s','%s')", sph_ptype, fname)
                return False

        data = {}

        # Add smoothing length field
        fname = "smoothing_length"
        if not exists(fname):
            hsml = generate_smoothing_length(pos[kdtree.idx], kdtree, n_neighbors)
            hsml = hsml[order]
            data[(sph_ptype, "smoothing_length")] = (hsml, l_unit)
        else:
            hsml = ad[sph_ptype, fname].to(l_unit).d

        # Add density field
        fname = "density"
        if not exists(fname):
            dens = estimate_density(
                pos[kdtree.idx],
                mass[kdtree.idx],
                hsml[kdtree.idx],
                kdtree,
                kernel_name=kernel,
            )
            dens = dens[order]
            data[(sph_ptype, "density")] = (dens, d_unit)

        # Add fields
        self._sph_ptypes = (sph_ptype,)
        self.index.update_data(data)
        self.num_neighbors = n_neighbors


# (i, j, k) offsets of the 8 corners of a unit cell, with the last axis
# varying fastest: (0,0,0), (0,0,1), (0,1,0), ..., (1,1,1).
_cis = np.fromiter(
    chain.from_iterable(product([0, 1], [0, 1], [0, 1])), dtype=np.int64, count=8 * 3
).reshape(8, 3)


def hexahedral_connectivity(xgrid, ygrid, zgrid):
    r"""Define the vertex coordinates and cell connectivity of a hexahedral
    mesh for a semistructured grid. Used to specify the connectivity and
    coordinates parameters used in
    :func:`~yt.frontends.stream.data_structures.load_hexahedral_mesh`.

    Parameters
    ----------
    xgrid : array_like
        x-coordinates of boundaries of the hexahedral cells. Should be a
        one-dimensional array (or sequence convertible with ``np.asarray``).
    ygrid : array_like
        y-coordinates of boundaries of the hexahedral cells. Should be a
        one-dimensional array (or sequence convertible with ``np.asarray``).
    zgrid : array_like
        z-coordinates of boundaries of the hexahedral cells. Should be a
        one-dimensional array (or sequence convertible with ``np.asarray``).

    Returns
    -------
    coords : array_like
        The list of (x,y,z) coordinates of the vertices of the mesh.
        Is of size (M,3) where M is the number of vertices.
    connectivity : array_like
        For each hexahedron h in the mesh, gives the indices (rows of
        ``coords``) of h's 8 vertices. Is of size (N,8), where N is the
        number of hexahedra.

    Examples
    --------

    >>> xgrid = np.array([-1, -0.25, 0, 0.25, 1])
    >>> coords, conn = hexahedral_connectivity(xgrid, xgrid, xgrid)
    >>> coords
    array([[-1.  , -1.  , -1.  ],
           [-1.  , -1.  , -0.25],
           [-1.  , -1.  ,  0.  ],
           ...,
           [ 1.  ,  1.  ,  0.  ],
           [ 1.  ,  1.  ,  0.25],
           [ 1.  ,  1.  ,  1.  ]])

    >>> conn
    array([[  0,   1,   5,   6,  25,  26,  30,  31],
           [  1,   2,   6,   7,  26,  27,  31,  32],
           [  2,   3,   7,   8,  27,  28,  32,  33],
           ...,
           [ 91,  92,  96,  97, 116, 117, 121, 122],
           [ 92,  93,  97,  98, 117, 118, 122, 123],
           [ 93,  94,  98,  99, 118, 119, 123, 124]])
    """
    # Accept any array_like (e.g. plain lists), as documented.
    xgrid = np.asarray(xgrid)
    ygrid = np.asarray(ygrid)
    zgrid = np.asarray(zgrid)
    nx = len(xgrid)
    ny = len(ygrid)
    nz = len(zgrid)
    coords = np.zeros((nx, ny, nz, 3), dtype="float64", order="C")
    coords[:, :, :, 0] = xgrid[:, None, None]
    coords[:, :, :, 1] = ygrid[None, :, None]
    coords[:, :, :, 2] = zgrid[None, None, :]
    coords = coords.reshape(nx * ny * nz, 3)
    # (i, j, k) index of the lower corner of every cell, one row per cell.
    cycle = np.moveaxis(np.indices((nx - 1, ny - 1, nz - 1)), 0, -1).reshape(
        (nx - 1) * (ny - 1) * (nz - 1), 3
    )
    # (N, 8, 3): (i, j, k) index of each of the 8 corners of every cell.
    off = _cis + cycle[:, np.newaxis]
    # Flatten corner (i, j, k) to the matching row of the C-ordered coords.
    connectivity = np.array(
        ((off[:, :, 0] * ny) + off[:, :, 1]) * nz + off[:, :, 2], order="C"
    )
    return coords, connectivity


class StreamHexahedralMesh(SemiStructuredMesh):
    _connectivity_length = 8
    _index_offset = 0


class StreamHexahedralHierarchy(UnstructuredIndex):
    """Unstructured index holding a single hexahedral mesh."""

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        # NOTE: pops coordinates/connectivity out of the handler's field
        # container, so they are no longer available there afterwards.
        coords = self.stream_handler.fields.pop("coordinates")
        connect = self.stream_handler.fields.pop("connectivity")
        self.meshes = [
            StreamHexahedralMesh(0, self.index_filename, connect, coords, self)
        ]

    def _setup_data_io(self):
        # Prefer a user-supplied I/O handler; otherwise use the registered one.
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _detect_output_fields(self):
        self.field_list = list(set(self.stream_handler.get_fields()))


class StreamHexahedralDataset(StreamDataset):
    _index_class = StreamHexahedralHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_hexahedral"


class StreamOctreeSubset(OctreeSubset):
    """Octree subset over a stream octree, optionally with ghost zones."""

    domain_id = 1
    _domain_offset = 1

    def __init__(
        self, base_region, ds, oct_handler, over_refine_factor=1, num_ghost_zones=0
    ):
        self._over_refine_factor = over_refine_factor
        self._num_zones = 1 << (over_refine_factor)
        self.field_data = YTFieldData()
        self.field_parameters = {}
        self.ds = ds
        self.oct_handler = oct_handler
        self._last_mask = None
        self._last_selector_id = None
        self._current_particle_type = "io"
        self._current_fluid_type = self.ds.default_fluid_type
        self.base_region = base_region
        self.base_selector = base_region.selector

        self._num_ghost_zones = num_ghost_zones

        if num_ghost_zones > 0:
            if not all(ds.periodicity):
                mylog.warning(
                    "Ghost zones will wrongly assume the domain to be periodic."
                )
            base_grid = StreamOctreeSubset(
                base_region, ds, oct_handler, over_refine_factor
            )
            self._base_grid = base_grid

    def retrieve_ghost_zones(self, ngz, fields, smoothed=False):
        """
        Return a subset over the same region with *ngz* ghost zones.

        The subset is cached on this object and reused only when it was built
        with the same number of ghost zones; otherwise a new one replaces it.
        *fields* and *smoothed* are accepted for interface compatibility and
        are not used here.
        """
        cached = getattr(self, "_subset_with_gz", None)
        if cached is not None and cached._num_ghost_zones == ngz:
            mylog.debug("Reusing previous subset with ghost zone.")
            return cached
        new_subset = StreamOctreeSubset(
            self.base_region,
            self.ds,
            self.oct_handler,
            self._over_refine_factor,
            num_ghost_zones=ngz,
        )
        self._subset_with_gz = new_subset
        return new_subset

    def _fill_no_ghostzones(self, content, dest, selector, offset):
        """Fill *dest* with float64 arrays of selected cells, no ghost zones."""
        # Here we get a copy of the file, which we skip through and read the
        # bits we want.
        oct_handler = self.oct_handler
        cell_count = selector.count_oct_cells(self.oct_handler, self.domain_id)
        levels, cell_inds, file_inds = self.oct_handler.file_index_octs(
            selector, self.domain_id, cell_count
        )
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        count = oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )
        return count

    def _fill_with_ghostzones(self, content, dest, selector, offset):
        """
        Fill *dest* including ghost zones; the ghost-zone file index is cached
        on this subset after the first call.
        """
        oct_handler = self.oct_handler
        ndim = self.ds.dimensionality
        cell_count = (
            selector.count_octs(self.oct_handler, self.domain_id) * self.nz ** ndim
        )

        gz_cache = getattr(self, "_ghost_zone_cache", None)
        if gz_cache:
            levels, cell_inds, file_inds, domains = gz_cache
        else:
            gz_cache = (
                levels,
                cell_inds,
                file_inds,
                domains,
            ) = oct_handler.file_index_octs_with_ghost_zones(
                selector, self.domain_id, cell_count
            )
            self._ghost_zone_cache = gz_cache
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ... Return fill_level's result, as the no-ghost-zone
        # path does, so fill() behaves the same on both branches.
        return oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )

    def fill(self, content, dest, selector, offset):
        """Dispatch to the ghost-zone or plain fill depending on this subset."""
        if self._num_ghost_zones == 0:
            return self._fill_no_ghostzones(content, dest, selector, offset)
        else:
            return self._fill_with_ghostzones(content, dest, selector, offset)


class StreamOctreeHandler(OctreeIndex):
    """Octree index built from the octree mask stored on the dataset."""

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        self.dataset_type = dataset_type
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        # Prefer a user-supplied I/O handler; otherwise use the registered one.
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _initialize_oct_handler(self):
        header = dict(
            dims=[1, 1, 1],
            left_edge=self.ds.domain_left_edge,
            right_edge=self.ds.domain_right_edge,
            octree=self.ds.octree_mask,
            over_refine=self.ds.over_refine_factor,
            partial_coverage=self.ds.partial_coverage,
        )
        self.oct_handler = OctreeContainer.load_octree(header)

    def _identify_base_chunk(self, dobj):
        # A single subset covers the whole octree.
        if getattr(dobj, "_chunk_info", None) is None:
            base_region = getattr(dobj, "base_region", dobj)
            subset = [
                StreamOctreeSubset(
                    base_region,
                    self.dataset,
                    self.oct_handler,
                    self.ds.over_refine_factor,
                )
            ]
            dobj._chunk_info = subset
        dobj._current_chunk = next(self._chunk_all(dobj))

    def _chunk_all(self, dobj):
        oobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        yield YTDataChunk(dobj, "all", oobjs, None)

    def _chunk_spatial(self, dobj, ngz, sort=None, preload_fields=None):
        # NOTE: ``sort`` and ``preload_fields`` are currently ignored.
        sobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        # This is where we will perform cutting of the Octree and
        # load-balancing. That may require a specialized selector object to
        # cut based on some space-filling curve index.
        for og in sobjs:
            if ngz > 0:
                g = og.retrieve_ghost_zones(ngz, [], smoothed=True)
            else:
                g = og
            yield YTDataChunk(dobj, "spatial", [g])

    def _chunk_io(self, dobj, cache=True, local_only=False):
        oobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        for subset in oobjs:
            yield YTDataChunk(dobj, "io", [subset], None, cache=cache)

    def _setup_classes(self):
        dd = self._get_data_reader_dict()
        super()._setup_classes(dd)

    def _detect_output_fields(self):
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        fl = set(self.stream_handler.get_fields())
        fl.update(set(getattr(self, "field_list", [])))
        self.field_list = list(fl)


class StreamOctreeDataset(StreamDataset):
    _index_class = StreamOctreeHandler
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_octree"

    levelmax = None

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
    ):
        super().__init__(
            stream_handler,
            storage_filename,
            geometry,
            unit_system,
            default_species_fields=default_species_fields,
        )
        # Record the range of refinement levels present in the handler.
        self.max_level = stream_handler.levels.max()
        self.min_level = stream_handler.levels.min()


class StreamUnstructuredMesh(UnstructuredMesh):
    _index_offset = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Nodes per element, taken from the connectivity array's width.
        self._connectivity_length = self.connectivity_indices.shape[1]


class StreamUnstructuredIndex(UnstructuredIndex):
    """Unstructured index with one mesh per connectivity array."""

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        # NOTE: pops coordinates/connectivity out of the handler's field
        # container. A single connectivity array is wrapped so it yields one
        # mesh; all meshes share the same coordinates.
        coords = self.stream_handler.fields.pop("coordinates")
        connect = always_iterable(self.stream_handler.fields.pop("connectivity"))

        self.meshes = [
            StreamUnstructuredMesh(i, self.index_filename, c1, c2, self)
            for i, (c1, c2) in enumerate(zip(connect, repeat(coords)))
        ]
        self.mesh_union = MeshUnion("mesh_union", self.meshes)

    def _setup_data_io(self):
        # Prefer a user-supplied I/O handler; otherwise use the registered one.
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _detect_output_fields(self):
        # Every field name is also exposed under the "all" type.
        self.field_list = list(set(self.stream_handler.get_fields()))
        fnames = list({fn for ft, fn in self.field_list})
        self.field_list += [("all", fname) for fname in fnames]


class StreamUnstructuredMeshDataset(StreamDataset):
    _index_class = StreamUnstructuredIndex
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_unstructured"

    def _find_particle_types(self):
        # Unstructured-mesh stream datasets skip particle-type detection.
        pass