from collections import defaultdict

import numpy as np

from yt.funcs import is_sequence
from yt.geometry.grid_container import GridTree, MatchPointsToGrids
from yt.utilities.exceptions import (
    YTInconsistentGridFieldShape,
    YTInconsistentGridFieldShapeGridDims,
    YTInconsistentParticleFieldShape,
)
from yt.utilities.logger import ytLogger as mylog

from .fields import StreamFieldInfo


def assign_particle_data(ds, pdata, bbox):
    """
    Assign particle data to the grids using MatchPointsToGrids. This
    will overwrite any existing particle data, so be careful!

    Parameters
    ----------
    ds : Dataset
        The stream dataset whose ``stream_handler`` receives the
        per-grid particle fields and counts.
    pdata : dict
        Particle fields, keyed either by plain field name or by
        ``(ptype, field)`` tuples; may include "number_of_particles".
    bbox : array-like, shape (3, 2)
        The domain bounding box, used only to suggest a larger bbox in
        the warning emitted when particles fall outside the domain.
    """
    # If a particle type has no position fields under (ptype, ...) keys,
    # reassign every bare array field in pdata to that particle type.
    for ptype in ds.particle_types_raw:
        check_fields = [(ptype, "particle_position_x"), (ptype, "particle_position")]
        if all(f not in pdata for f in check_fields):
            pdata_ftype = {}
            for f in sorted(pdata):
                if not hasattr(pdata[f], "shape"):
                    continue
                if f == "number_of_particles":
                    continue
                mylog.debug("Reassigning '%s' to ('%s','%s')", f, ptype, f)
                pdata_ftype[ptype, f] = pdata.pop(f)
            pdata_ftype.update(pdata)
            pdata = pdata_ftype

    # Note: what we need to do here is a bit tricky. Because occasionally this
    # gets called before we properly handle the field detection, we cannot use
    # any information about the index. Fortunately for us, we can generate
    # most of the GridTree utilizing information we already have from the
    # stream handler.

    if len(ds.stream_handler.fields) > 1:
        pdata.pop("number_of_particles", None)
        num_grids = len(ds.stream_handler.fields)
        parent_ids = ds.stream_handler.parent_ids
        num_children = np.zeros(num_grids, dtype="int64")
        # We're going to do this the slow way: count, for each grid, how
        # many grids name it as their parent.
        mask = np.empty(num_grids, dtype="bool")
        for i in range(num_grids):
            np.equal(parent_ids, i, mask)
            num_children[i] = mask.sum()
        levels = ds.stream_handler.levels.astype("int64").ravel()
        grid_tree = GridTree(
            num_grids,
            ds.stream_handler.left_edges,
            ds.stream_handler.right_edges,
            ds.stream_handler.dimensions,
            ds.stream_handler.parent_ids,
            levels,
            num_children,
        )

        grid_pdata = []
        for _ in range(num_grids):
            grid = {"number_of_particles": 0}
            grid_pdata.append(grid)

        for ptype in ds.particle_types_raw:
            if (ptype, "particle_position_x") in pdata:
                x, y, z = (pdata[ptype, f"particle_position_{ax}"] for ax in "xyz")
            elif (ptype, "particle_position") in pdata:
                x, y, z = pdata[ptype, "particle_position"].T
            else:
                raise KeyError(
                    "Cannot decompose particle data without position fields!"
                )
            pts = MatchPointsToGrids(grid_tree, len(x), x, y, z)
            particle_grid_inds = pts.find_points_in_tree()
            # Negative indices mark particles outside every grid.
            (assigned_particles,) = (particle_grid_inds >= 0).nonzero()
            num_particles = particle_grid_inds.size
            num_unassigned = num_particles - assigned_particles.size
            if num_unassigned > 0:
                # Suggest a bounding box (padded by machine epsilon) that
                # would have contained all of the particles.
                eps = np.finfo(x.dtype).eps
                s = np.array(
                    [
                        [x.min() - eps, x.max() + eps],
                        [y.min() - eps, y.max() + eps],
                        [z.min() - eps, z.max() + eps],
                    ]
                )
                sug_bbox = [
                    [min(bbox[0, 0], s[0, 0]), max(bbox[0, 1], s[0, 1])],
                    [min(bbox[1, 0], s[1, 0]), max(bbox[1, 1], s[1, 1])],
                    [min(bbox[2, 0], s[2, 0]), max(bbox[2, 1], s[2, 1])],
                ]
                mylog.warning(
                    "Discarding %s particles (out of %s) that are outside "
                    "bounding box. Set bbox=%s to avoid this in the future.",
                    num_unassigned,
                    num_particles,
                    sug_bbox,
                )
                particle_grid_inds = particle_grid_inds[assigned_particles]
                x = x[assigned_particles]
                y = y[assigned_particles]
                z = z[assigned_particles]
            # Sort particles by destination grid so each grid owns a
            # contiguous slice of the sorted arrays.
            idxs = np.argsort(particle_grid_inds)
            particle_grid_count = np.bincount(
                particle_grid_inds.astype("intp"), minlength=num_grids
            )
            particle_indices = np.zeros(num_grids + 1, dtype="int64")
            if num_grids > 1:
                np.add.accumulate(
                    particle_grid_count.squeeze(), out=particle_indices[1:]
                )
            else:
                particle_indices[1] = particle_grid_count.squeeze()
            # Sort each field once up front; the original re-applied the
            # fancy index per grid, which is O(num_grids * num_particles).
            sorted_fields = {
                key: val[idxs] for key, val in pdata.items() if key[0] == ptype
            }
            for i, pcount in enumerate(particle_grid_count):
                grid_pdata[i]["number_of_particles"] += pcount
                start = particle_indices[i]
                end = particle_indices[i + 1]
                for key, val in sorted_fields.items():
                    grid_pdata[i][key] = val[start:end]
    else:
        grid_pdata = [pdata]

    # Push the decomposed fields back onto the stream handler, recording
    # the particle count per grid along the way.
    for pd, gi in zip(grid_pdata, sorted(ds.stream_handler.fields)):
        ds.stream_handler.fields[gi].update(pd)
        ds.stream_handler.particle_types.update(set_particle_types(pd))
        npart = ds.stream_handler.fields[gi].pop("number_of_particles", 0)
        ds.stream_handler.particle_count[gi] = npart


def process_data(data, grid_dims=None):
    """
    Normalize a user-supplied data dict into arrays plus unit strings.

    Accepts values that are ndarrays (or YTArrays carrying units),
    ``(ndarray, unit-string)`` tuples, or generic sequences; promotes bare
    field names to ``("io", name)`` (particle-shaped, rank 1-2) or
    ``("stream", name)`` (grid-shaped, rank 3) tuples, and sanity-checks
    that all grid fields — and all fields of each particle type — share a
    consistent shape.

    Parameters
    ----------
    data : dict
        Mapping of field name (str or tuple) to field data.
    grid_dims : array-like, optional
        Expected grid dimensions; if given, grid fields must match it.

    Returns
    -------
    field_units : dict
        Mapping of tuple field name to unit specification.
    data : dict
        Mapping of tuple field name to ndarray.
    number_of_particles : int
        Total particle count summed over particle types.

    Raises
    ------
    RuntimeError
        If the data dict is malformed.
    YTInconsistentGridFieldShape, YTInconsistentGridFieldShapeGridDims,
    YTInconsistentParticleFieldShape
        If field shapes are inconsistent.
    """
    new_data, field_units = {}, {}
    for field, val in data.items():
        # val is a data array
        if isinstance(val, np.ndarray):
            # val is a YTArray
            if hasattr(val, "units"):
                field_units[field] = val.units
                new_data[field] = val.copy().d
            # val is a numpy array
            else:
                field_units[field] = ""
                new_data[field] = val.copy()

        # val is a tuple of (data, units)
        elif isinstance(val, tuple) and len(val) == 2:
            try:
                assert isinstance(field, (str, tuple)), "Field name is not a string!"
                assert isinstance(val[0], np.ndarray), "Field data is not an ndarray!"
                assert isinstance(val[1], str), "Unit specification is not a string!"
                field_units[field] = val[1]
                new_data[field] = val[0]
            except AssertionError as e:
                # Chain the assertion so the original context is preserved.
                raise RuntimeError(
                    "The data dict appears to be invalid.\n" + str(e)
                ) from e

        # val is a list of data to be turned into an array
        elif is_sequence(val):
            field_units[field] = ""
            new_data[field] = np.asarray(val)

        else:
            raise RuntimeError(
                "The data dict appears to be invalid. "
                "The data dictionary must map from field "
                "names to (numpy array, unit spec) tuples. "
            )

    data = new_data

    # At this point, we have arrays for all our fields. Hoisted out of the
    # loop below: the known-field list does not change per field.
    known_fields = (
        StreamFieldInfo.known_particle_fields + StreamFieldInfo.known_other_fields
    )
    new_data = {}
    for field in data:
        n_shape = len(data[field].shape)
        if isinstance(field, tuple):
            new_field = field
        elif n_shape in (1, 2):
            new_field = ("io", field)
        elif n_shape == 3:
            new_field = ("stream", field)
        else:
            raise RuntimeError(
                f"Field {field!r} has unsupported rank {n_shape}; "
                "expected 1, 2, or 3 dimensions."
            )
        new_data[new_field] = data[field]
        field_units[new_field] = field_units.pop(field)
        # We do not want to override any of the known ones, if it's not
        # overridden here.
        if (
            any(f[0] == new_field[1] for f in known_fields)
            and field_units[new_field] == ""
        ):
            field_units.pop(new_field)
    data = new_data
    # Sanity checking that all fields have the same dimensions.
    g_shapes = []
    p_shapes = defaultdict(list)
    for field in data:
        f_shape = data[field].shape
        n_shape = len(f_shape)
        if n_shape in (1, 2):
            p_shapes[field[0]].append((field[1], f_shape[0]))
        elif n_shape == 3:
            g_shapes.append((field, f_shape))
    if len(g_shapes) > 0:
        g_s = np.array([s[1] for s in g_shapes])
        if not np.all(g_s == g_s[0]):
            raise YTInconsistentGridFieldShape(g_shapes)
        if grid_dims is not None:
            if not np.all(g_s == grid_dims):
                raise YTInconsistentGridFieldShapeGridDims(g_shapes, grid_dims)
    if len(p_shapes) > 0:
        for ptype, p_shape in p_shapes.items():
            p_s = np.array([s[1] for s in p_shape])
            if not np.all(p_s == p_s[0]):
                raise YTInconsistentParticleFieldShape(ptype, p_shape)
    # Now that we know the particle fields are consistent, determine the number
    # of particles.
    if len(p_shapes) > 0:
        number_of_particles = np.sum([s[0][1] for s in p_shapes.values()])
    else:
        number_of_particles = 0
    return field_units, data, number_of_particles


def set_particle_types(data):
    """
    Flag each field in *data* as particle (True) or mesh (False).

    A field is treated as a particle field when its array is 1D; the
    bookkeeping key "number_of_particles" is skipped entirely.
    """
    particle_types = {}
    for key, val in data.items():
        if key == "number_of_particles":
            continue
        particle_types[key] = len(val.shape) == 1
    return particle_types