from collections import defaultdict

import numpy as np

from yt.funcs import is_sequence
from yt.geometry.grid_container import GridTree, MatchPointsToGrids
from yt.utilities.exceptions import (
    YTInconsistentGridFieldShape,
    YTInconsistentGridFieldShapeGridDims,
    YTInconsistentParticleFieldShape,
)
from yt.utilities.logger import ytLogger as mylog

from .fields import StreamFieldInfo


def assign_particle_data(ds, pdata, bbox):
    """
    Assign particle data to the grids using MatchPointsToGrids. This
    will overwrite any existing particle data, so be careful!
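
    Examples
    --------
    A minimal sketch of typical usage (not a runnable doctest): ``ds`` is
    assumed to be an existing stream dataset, e.g. from
    ``yt.load_amr_grids``, and ``bbox`` the bounding box it was built with::

        >>> import numpy as np
        >>> pdata = {
        ...     "particle_position_x": np.random.uniform(size=1000),
        ...     "particle_position_y": np.random.uniform(size=1000),
        ...     "particle_position_z": np.random.uniform(size=1000),
        ... }
        >>> assign_particle_data(ds, pdata, bbox)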
22    """

    for ptype in ds.particle_types_raw:
        check_fields = [(ptype, "particle_position_x"), (ptype, "particle_position")]
        if all(f not in pdata for f in check_fields):
            pdata_ftype = {}
            for f in sorted(pdata):
                if not hasattr(pdata[f], "shape"):
                    continue
                if f == "number_of_particles":
                    continue
                mylog.debug("Reassigning '%s' to ('%s','%s')", f, ptype, f)
                pdata_ftype[ptype, f] = pdata.pop(f)
            pdata_ftype.update(pdata)
            pdata = pdata_ftype

    # Note: what we need to do here is a bit tricky.  Because occasionally this
    # gets called before we properly handle the field detection, we cannot use
    # any information about the index.  Fortunately for us, we can generate
    # most of the GridTree utilizing information we already have from the
    # stream handler.

    if len(ds.stream_handler.fields) > 1:
        pdata.pop("number_of_particles", None)
        num_grids = len(ds.stream_handler.fields)
        parent_ids = ds.stream_handler.parent_ids
        num_children = np.zeros(num_grids, dtype="int64")
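        # Count each grid's children: grid i's child count is the number of
        # grids whose parent_id equals i.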
        # We're going to do this the slow way
        mask = np.empty(num_grids, dtype="bool")
        for i in range(num_grids):
            np.equal(parent_ids, i, out=mask)
            num_children[i] = mask.sum()
        levels = ds.stream_handler.levels.astype("int64").ravel()
        grid_tree = GridTree(
            num_grids,
            ds.stream_handler.left_edges,
            ds.stream_handler.right_edges,
            ds.stream_handler.dimensions,
            ds.stream_handler.parent_ids,
            levels,
            num_children,
        )

        grid_pdata = []
        for _ in range(num_grids):
            grid = {"number_of_particles": 0}
            grid_pdata.append(grid)

        for ptype in ds.particle_types_raw:
            if (ptype, "particle_position_x") in pdata:
                x, y, z = (pdata[ptype, f"particle_position_{ax}"] for ax in "xyz")
            elif (ptype, "particle_position") in pdata:
                x, y, z = pdata[ptype, "particle_position"].T
            else:
                raise KeyError(
                    "Cannot decompose particle data without position fields!"
                )
            pts = MatchPointsToGrids(grid_tree, len(x), x, y, z)
            particle_grid_inds = pts.find_points_in_tree()
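            # Each entry of particle_grid_inds is the index of the grid that
            # contains the corresponding particle, or -1 if it falls outside
            # every grid.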
            (assigned_particles,) = (particle_grid_inds >= 0).nonzero()
            num_particles = particle_grid_inds.size
            num_unassigned = num_particles - assigned_particles.size
            if num_unassigned > 0:
                eps = np.finfo(x.dtype).eps
                s = np.array(
                    [
                        [x.min() - eps, x.max() + eps],
                        [y.min() - eps, y.max() + eps],
                        [z.min() - eps, z.max() + eps],
                    ]
                )
                sug_bbox = [
                    [min(bbox[0, 0], s[0, 0]), max(bbox[0, 1], s[0, 1])],
                    [min(bbox[1, 0], s[1, 0]), max(bbox[1, 1], s[1, 1])],
                    [min(bbox[2, 0], s[2, 0]), max(bbox[2, 1], s[2, 1])],
                ]
                mylog.warning(
                    "Discarding %s particles (out of %s) that are outside "
                    "bounding box. Set bbox=%s to avoid this in the future.",
                    num_unassigned,
                    num_particles,
                    sug_bbox,
                )
                particle_grid_inds = particle_grid_inds[assigned_particles]
                x = x[assigned_particles]
                y = y[assigned_particles]
                z = z[assigned_particles]
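            # Sort particles by grid index so each grid's particles are
            # contiguous, then build prefix-sum offsets: grid i's particles
            # sit at idxs[particle_indices[i]:particle_indices[i + 1]].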
            idxs = np.argsort(particle_grid_inds)
            particle_grid_count = np.bincount(
                particle_grid_inds.astype("intp"), minlength=num_grids
            )
            particle_indices = np.zeros(num_grids + 1, dtype="int64")
            if num_grids > 1:
                np.add.accumulate(
                    particle_grid_count.squeeze(), out=particle_indices[1:]
                )
            else:
                particle_indices[1] = particle_grid_count.squeeze()
            for i, pcount in enumerate(particle_grid_count):
                grid_pdata[i]["number_of_particles"] += pcount
                start = particle_indices[i]
                end = particle_indices[i + 1]
                for key in pdata:
                    if key[0] == ptype:
                        grid_pdata[i][key] = pdata[key][idxs][start:end]

    else:
        grid_pdata = [pdata]

    for pd, gi in zip(grid_pdata, sorted(ds.stream_handler.fields)):
        ds.stream_handler.fields[gi].update(pd)
        ds.stream_handler.particle_types.update(set_particle_types(pd))
        npart = ds.stream_handler.fields[gi].pop("number_of_particles", 0)
        ds.stream_handler.particle_count[gi] = npart


def process_data(data, grid_dims=None):
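    """
    Normalize a stream data dict: coerce each value to a bare ndarray,
    record its unit specification, and promote plain field names to
    (field type, field name) tuples. Returns a tuple of
    ``(field_units, data, number_of_particles)``.

    A minimal sketch of the expected input, assuming a uniform 8**3 grid::

        >>> import numpy as np
        >>> field_units, data, npart = process_data(
        ...     {"density": (np.ones((8, 8, 8)), "g/cm**3")}
        ... )
    """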
    new_data, field_units = {}, {}
    for field, val in data.items():
        # val is a data array
        if isinstance(val, np.ndarray):
            # val is a YTArray
            if hasattr(val, "units"):
                field_units[field] = val.units
                new_data[field] = val.copy().d
            # val is a numpy array
            else:
                field_units[field] = ""
                new_data[field] = val.copy()

        # val is a tuple of (data, units)
        elif isinstance(val, tuple) and len(val) == 2:
            try:
                assert isinstance(
                    field, (str, tuple)
                ), "Field name is not a string or tuple!"
                assert isinstance(val[0], np.ndarray), "Field data is not an ndarray!"
                assert isinstance(val[1], str), "Unit specification is not a string!"
                field_units[field] = val[1]
                new_data[field] = val[0]
            except AssertionError as e:
                raise RuntimeError(
                    "The data dict appears to be invalid.\n" + str(e)
                ) from e

        # val is a list of data to be turned into an array
        elif is_sequence(val):
            field_units[field] = ""
            new_data[field] = np.asarray(val)

        else:
            raise RuntimeError(
                "The data dict appears to be invalid. "
                "The data dictionary must map from field "
                "names to (numpy array, unit spec) tuples."
            )

    data = new_data

    # At this point, we have arrays for all our fields
    new_data = {}
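    # Promote bare field names to (field type, field name) tuples: 1D and 2D
    # arrays are taken to be particle fields under the "io" type, while 3D
    # arrays are gridded fields under "stream".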
    for field in data:
        n_shape = len(data[field].shape)
        if isinstance(field, tuple):
            new_field = field
        elif n_shape in (1, 2):
            new_field = ("io", field)
        elif n_shape == 3:
            new_field = ("stream", field)
        else:
            raise RuntimeError(
                f"Field {field!r} has unsupported shape {data[field].shape}."
            )
        new_data[new_field] = data[field]
        field_units[new_field] = field_units.pop(field)
        known_fields = (
            StreamFieldInfo.known_particle_fields + StreamFieldInfo.known_other_fields
        )
        # If this is a known field and no units were supplied, drop the
        # empty unit spec so the field's default units are not overridden.
        if (
            any(f[0] == new_field[1] for f in known_fields)
            and field_units[new_field] == ""
        ):
            field_units.pop(new_field)
    data = new_data
    # Sanity checking that all fields have the same dimensions.
    g_shapes = []
    p_shapes = defaultdict(list)
    for field in data:
        f_shape = data[field].shape
        n_shape = len(f_shape)
        if n_shape in (1, 2):
            p_shapes[field[0]].append((field[1], f_shape[0]))
        elif n_shape == 3:
            g_shapes.append((field, f_shape))
    if len(g_shapes) > 0:
        g_s = np.array([s[1] for s in g_shapes])
        if not np.all(g_s == g_s[0]):
            raise YTInconsistentGridFieldShape(g_shapes)
        if grid_dims is not None:
            if not np.all(g_s == grid_dims):
                raise YTInconsistentGridFieldShapeGridDims(g_shapes, grid_dims)
    if len(p_shapes) > 0:
        for ptype, p_shape in p_shapes.items():
            p_s = np.array([s[1] for s in p_shape])
            if not np.all(p_s == p_s[0]):
                raise YTInconsistentParticleFieldShape(ptype, p_shape)
    # Now that we know the particle fields are consistent, determine the number
    # of particles.
    if len(p_shapes) > 0:
        number_of_particles = np.sum([s[0][1] for s in p_shapes.values()])
    else:
        number_of_particles = 0
    return field_units, data, number_of_particles


def set_particle_types(data):
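    """
    Map each field key in ``data`` to True if it looks like a particle field
    (a 1D array) and False otherwise, skipping the special
    "number_of_particles" entry.
    """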
    particle_types = {}
    for key in data:
        if key == "number_of_particles":
            continue
        particle_types[key] = len(data[key].shape) == 1
    return particle_types