1import uuid
2
3import numpy as np
4
5from yt.fields.derived_field import ValidateSpatial
6from yt.frontends.ytdata.utilities import save_as_dataset
7from yt.funcs import get_output_filename, mylog
8from yt.utilities.tree_container import TreeContainer
9
10from .clump_info_items import clump_info_registry
11from .clump_validators import clump_validator_registry
12from .contour_finder import identify_contours
13
14
def add_contour_field(ds, contour_key):
    """Register the ``("index", "contours_<contour_key>")`` field on ``ds``.

    The registered field evaluates to -1 in every cell unless the data
    object carries a ``contour_slices_<contour_key>`` field parameter; in
    that case, for each ``(slice, value)`` pair stored under the data
    object's ``id``, the sliced cells are overwritten with ``value``.
    A field parameter that is missing (``None``) or equal to ``0.0`` leaves
    the field at -1 everywhere.
    """
    parameter_name = f"contour_slices_{contour_key}"

    def _contours(field, data):
        slices_by_id = data.get_field_parameter(parameter_name)
        contour_ids = data["index", "ones"] * -1
        if slices_by_id is not None and slices_by_id != 0.0:
            for region, ids in slices_by_id.get(data.id, []):
                contour_ids[region] = ids
        return contour_ids

    ds.add_field(
        ("index", f"contours_{contour_key}"),
        function=_contours,
        sampling_type="cell",
        units="",
        validators=[ValidateSpatial(0)],
        take_log=False,
        display_field=False,
    )
34
35
class Clump(TreeContainer):
    """A region of a data container, arranged in a tree of sub-clumps.

    Each clump wraps a yt data object (``data``) together with the field used
    for contouring.  Child clumps are produced by :meth:`find_children`, and
    :func:`find_clumps` drives the recursive search.  Children created by
    :meth:`find_children` receive the very same ``validators`` and
    ``clump_info`` list objects as their parent.

    Parameters
    ----------
    data : data container
        The region covered by this clump.
    field : str or tuple
        The field used to identify contours.
    parent : Clump, optional
        The parent clump; ``None`` for the root of the tree.
    clump_info : list, optional
        Info-item callbacks to evaluate on this clump.  When ``None`` the
        default set from :meth:`set_default_clump_info` is used.
    validators : list, optional
        Validator callbacks applied by :meth:`_validate`.
    base : Clump, optional
        The root clump of the tree; ``None`` makes this clump the root.
    contour_key : str, optional
        Key of the contour field this clump was cut from.
    contour_id : int, optional
        Value of that contour field selecting this clump.
    """

    def __init__(
        self,
        data,
        field,
        parent=None,
        clump_info=None,
        validators=None,
        base=None,
        contour_key=None,
        contour_id=None,
    ):
        self.data = data
        self.field = field
        self.parent = parent
        self.quantities = data.quantities
        self.min_val = self.data[field].min()
        self.max_val = self.data[field].max()
        self.info = {}
        self.children = []

        # is this the parent clump?
        if base is None:
            base = self
            self.total_clumps = 0

        if clump_info is None:
            # add_info_item already evaluates every default item on this
            # clump, so there is no need to run them a second time.
            self.set_default_clump_info()
        else:
            self.clump_info = clump_info
            for ci in self.clump_info:
                ci(self)

        # Clump ids are handed out sequentially by the root clump.
        self.base = base
        self.clump_id = self.base.total_clumps
        self.base.total_clumps += 1
        self.contour_key = contour_key
        self.contour_id = contour_id

        if parent is not None:
            self.data.parent = self.parent.data

        if validators is None:
            validators = []
        self.validators = validators
        # Return value of validity function (None until _validate runs).
        self.valid = None

    _leaves = None

    @property
    def leaves(self):
        """List of clumps in this subtree that have no children.

        The list is computed on first access and cached; it is not
        recomputed if the tree is modified afterwards.
        """
        if self._leaves is not None:
            return self._leaves

        self._leaves = []
        for clump in self:
            if not clump.children:
                self._leaves.append(clump)
        return self._leaves

    def add_validator(self, validator, *args, **kwargs):
        """
        Add a validating function to determine whether the clump should
        be kept.

        The validator is looked up once with the given arguments and
        attached to this clump and to every descendant.

        Parameters
        ----------
        validator : str
            Name of a validator in ``clump_validator_registry``.
        *args, **kwargs
            Extra arguments forwarded to the registry lookup.
        """
        callback = clump_validator_registry.find(validator, *args, **kwargs)
        # Children share their parent's validators list, so only append
        # where this exact callback is not already present.  Reusing the
        # same callback also keeps the user's arguments for descendants.
        for clump in self:
            if not any(cb is callback for cb in clump.validators):
                clump.validators.append(callback)

    def add_info_item(self, info_item, *args, **kwargs):
        """Adds an entry to clump_info list and tells children to do the same.

        The info item is looked up once with the given arguments, evaluated
        on this clump and every descendant, and recorded in each clump's
        ``clump_info`` list.

        Parameters
        ----------
        info_item : str
            Name of an item in ``clump_info_registry``.
        *args, **kwargs
            Extra arguments forwarded to the registry lookup.
        """
        callback = clump_info_registry.find(info_item, *args, **kwargs)
        # Children normally share their parent's clump_info list (until
        # clear_clump_info gives each clump its own), so guard against
        # appending the same callback twice to one list.
        for clump in self:
            callback(clump)
            if not any(ci is callback for ci in clump.clump_info):
                clump.clump_info.append(callback)

    def set_default_clump_info(self):
        "Defines default entries in the clump_info array."

        # add_info_item applies each item to the whole subtree, so this
        # function does not need to recurse.
        self.clump_info = []

        self.add_info_item("total_cells")
        self.add_info_item("cell_mass")

        # NOTE(review): if field_list holds (ftype, fname) tuples, ``in``
        # tests tuple membership rather than a substring of the field name,
        # so these checks only match a component exactly equal to "jeans" /
        # "number_density".  Confirm intent before changing the defaults.
        if any("jeans" in f for f in self.data.pf.field_list):
            self.add_info_item("mass_weighted_jeans_mass")
            self.add_info_item("volume_weighted_jeans_mass")

        self.add_info_item("max_grid_level")

        if any("number_density" in f for f in self.data.pf.field_list):
            self.add_info_item("min_number_density")
            self.add_info_item("max_number_density")

    def clear_clump_info(self):
        """
        Clears the clump_info array and passes the instruction to its
        children.

        Each clump ends up with its own new, empty list.
        """

        self.clump_info = []
        for child in self.children:
            child.clear_clump_info()

    def find_children(self, min_val, max_val=None):
        """Replace this clump's children with the contours of ``self.field``.

        Contours are identified between ``min_val`` and ``max_val`` (default:
        this clump's maximum field value).  A fresh contour field is
        registered on the dataset for this call, and every non-empty contour
        becomes a child ``Clump`` that shares this clump's ``validators`` and
        ``clump_info`` lists.
        """
        if self.children:
            mylog.info("Wiping out existing children clumps: %d.", len(self.children))
        self.children = []
        if max_val is None:
            max_val = self.max_val
        nj, cids = identify_contours(self.data, self.field, min_val, max_val)
        # Here, cids is the set of slices and values, keyed by the
        # parent_grid_id, that defines the contours.  So we can figure out all
        # the unique values of the contours by examining the list here.
        unique_contours = set()
        for sl_list in cids.values():
            for _sl, ff in sl_list:
                unique_contours.update(np.unique(ff))
        contour_key = uuid.uuid4().hex
        base_object = getattr(self.data, "base_object", self.data)
        add_contour_field(base_object.ds, contour_key)
        for cid in sorted(unique_contours):
            # -1 marks cells that belong to no contour.
            if cid == -1:
                continue
            new_clump = base_object.cut_region(
                [f"obj['contours_{contour_key}'] == {cid}"],
                {f"contour_slices_{contour_key}": cids},
            )
            if new_clump[("index", "ones")].size == 0:
                # This is to skip possibly duplicate clumps.
                # Using "ones" here will speed things up.
                continue
            self.children.append(
                Clump(
                    new_clump,
                    self.field,
                    parent=self,
                    validators=self.validators,
                    base=self.base,
                    clump_info=self.clump_info,
                    contour_key=contour_key,
                    contour_id=cid,
                )
            )

    def __iter__(self):
        """Yield this clump, then every descendant depth-first (pre-order)."""
        yield self
        for child in self.children:
            yield from child

    def save_as_dataset(self, filename=None, fields=None):
        r"""Export clump tree to a reloadable yt dataset.
        This function will take a clump object and output a dataset
        containing the fields given in the ``fields`` list and all info
        items.  The resulting dataset can be reloaded as a yt dataset.

        Parameters
        ----------
        filename : str, optional
            The name of the file to be written.  If None, the name
            will be a combination of the original dataset and the clump
            index.
        fields : list of strings or tuples, optional
            If this is supplied, it is the list of fields to be saved to
            disk.

        Returns
        -------
        filename : str
            The name of the file that has been created.

        Examples
        --------

        >>> import numpy as np
        >>> import yt
        >>> from yt.data_objects.level_sets.api import Clump, find_clumps
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> data_source = ds.disk(
        ...     [0.5, 0.5, 0.5], [0.0, 0.0, 1.0], (8, "kpc"), (1, "kpc")
        ... )
        >>> field = ("gas", "density")
        >>> step = 2.0
        >>> c_min = 10 ** np.floor(np.log10(data_source[field]).min())
        >>> c_max = 10 ** np.floor(np.log10(data_source[field]).max() + 1)
        >>> master_clump = Clump(data_source, field)
        >>> master_clump.add_info_item("center_of_mass")
        >>> master_clump.add_validator("min_cells", 20)
        >>> find_clumps(master_clump, c_min, c_max, step)
        >>> fn = master_clump.save_as_dataset(
        ...     fields=[("gas", "density"), ("all", "particle_mass")]
        ... )
        >>> new_ds = yt.load(fn)
        >>> print(new_ds.tree["clump", "cell_mass"])
        1296926163.91 Msun
        >>> print(new_ds.tree["grid", "density"])
        [  2.54398434e-26   2.46620353e-26   2.25120154e-26 ...,   1.12879234e-25
           1.59561490e-25   1.09824903e-24] g/cm**3
        >>> print(new_ds.tree["all", "particle_mass"])
        [  4.25472446e+38   4.25472446e+38   4.25472446e+38 ...,   2.04238266e+38
           2.04523901e+38   2.04770938e+38] g
        >>> print(new_ds.tree.children[0]["clump", "cell_mass"])
        909636495.312 Msun
        >>> print(new_ds.leaves[0]["clump", "cell_mass"])
        3756566.99809 Msun
        >>> print(new_ds.leaves[0]["grid", "density"])
        [  6.97820274e-24   6.58117370e-24   7.32046082e-24   6.76202430e-24
           7.41184837e-24   6.76981480e-24   6.94287213e-24   6.56149658e-24
           6.76584569e-24   6.94073710e-24   7.06713082e-24   7.22556526e-24
           7.08338898e-24   6.78684331e-24   7.40647040e-24   7.03050456e-24
           7.12438678e-24   6.56310217e-24   7.23201662e-24   7.17314333e-24] g/cm**3

        """

        ds = self.data.ds
        keyword = "%s_clump_%d" % (str(ds), self.clump_id)
        filename = get_output_filename(filename, keyword, ".h5")

        # collect clump info fields
        clump_info = {ci.name: [] for ci in self.base.clump_info}
        clump_info.update(
            {
                field: []
                for field in ["clump_id", "parent_id", "contour_key", "contour_id"]
            }
        )
        for clump in self:
            clump_info["clump_id"].append(clump.clump_id)
            if clump.parent is None:
                parent_id = -1
            else:
                parent_id = clump.parent.clump_id
            clump_info["parent_id"].append(parent_id)

            # -1 stands in for "not set" so every column stays fully populated.
            contour_key = clump.contour_key
            if contour_key is None:
                contour_key = -1
            clump_info["contour_key"].append(contour_key)
            contour_id = clump.contour_id
            if contour_id is None:
                contour_id = -1
            clump_info["contour_id"].append(contour_id)

            for ci in self.base.clump_info:
                clump_info[ci.name].append(clump.info[ci.name][1])
        for ci in clump_info:
            if hasattr(clump_info[ci][0], "units"):
                clump_info[ci] = ds.arr(clump_info[ci])
            else:
                clump_info[ci] = np.array(clump_info[ci])

        ftypes = {ci: "clump" for ci in clump_info}

        # collect data fields
        if fields is not None:
            # The contour_key column mixes strings and -1, so numpy stores
            # it as strings; compare against "-1" accordingly.
            contour_fields = [
                ("index", f"contours_{ckey}")
                for ckey in np.unique(clump_info["contour_key"])
                if str(ckey) != "-1"
            ]

            ptypes = []
            field_data = {}
            need_grid_positions = False
            for f in self.base.data._determine_fields(fields) + contour_fields:
                if ds.field_info[f].sampling_type == "particle":
                    if f[0] not in ptypes:
                        ptypes.append(f[0])
                    ftypes[f] = f[0]
                else:
                    need_grid_positions = True
                    if f[1] in ("x", "y", "z", "dx", "dy", "dz"):
                        # skip 'xyz' if a user passes that in because they
                        # will be added to ftypes below
                        continue
                    ftypes[f] = "grid"
                field_data[f] = self.base[f]

            if len(ptypes) > 0:
                for ax in "xyz":
                    for ptype in ptypes:
                        p_field = (ptype, f"particle_position_{ax}")
                        if p_field in ds.field_info and p_field not in field_data:
                            ftypes[p_field] = p_field[0]
                            field_data[p_field] = self.base[p_field]

                for clump in self:
                    if clump.contour_key is None:
                        continue
                    for ptype in ptypes:
                        cfield = (ptype, f"contours_{clump.contour_key}")
                        in_clump = clump.data._part_ind(ptype)
                        if cfield not in field_data:
                            # Particles outside every clump that shares this
                            # contour key are marked -1, the same "no contour"
                            # sentinel used for cells, so they cannot be
                            # mistaken for a clump whose contour id is 0.
                            field_data[cfield] = np.full(
                                in_clump.shape, -1, dtype=np.int64
                            )
                            ftypes[cfield] = ptype
                        field_data[cfield][in_clump] = clump.contour_id

            if need_grid_positions:
                for ax in "xyz":
                    g_field = ("index", ax)
                    if g_field in ds.field_info and g_field not in field_data:
                        field_data[g_field] = self.base[g_field]
                        ftypes[g_field] = "grid"
                    g_field = ("index", "d" + ax)
                    if g_field in ds.field_info and g_field not in field_data:
                        ftypes[g_field] = "grid"
                        field_data[g_field] = self.base[g_field]

            # When saving a sub-clump, keep only the elements that belong to
            # it.  Each filter is computed once per contour field and cached,
            # so filtering the contour field itself is safe.
            if self.contour_key is not None:
                cfilters = {}
                for field in field_data:
                    if ftypes[field] == "grid":
                        ftype = "index"
                    else:
                        ftype = field[0]
                    cfield = (ftype, f"contours_{self.contour_key}")
                    if cfield not in cfilters:
                        cfilters[cfield] = field_data[cfield] == self.contour_id
                    field_data[field] = field_data[field][cfilters[cfield]]

            clump_info.update(field_data)
        extra_attrs = {"data_type": "yt_clump_tree", "container_type": "yt_clump_tree"}
        save_as_dataset(
            ds, filename, clump_info, field_types=ftypes, extra_attrs=extra_attrs
        )

        return filename

    def pass_down(self, operation):
        """
        Performs an operation on a clump with an exec and passes the
        instruction down to clump children.

        ``operation`` may be a callable (called with no arguments) or a
        string of Python code.
        """

        # Call if callable, otherwise do an exec.
        if callable(operation):
            operation()
        else:
            # SECURITY: string operations are executed with exec(); never
            # pass text that comes from an untrusted source.
            exec(operation)

        for child in self.children:
            child.pass_down(operation)

    def _validate(self):
        "Apply all user specified validator functions."

        # Only call functions if not done already.
        if self.valid is not None:
            return self.valid

        self.valid = True
        for validator in self.validators:
            self.valid &= validator(self)
            if not self.valid:
                # Stop at the first failing validator.
                break

        return self.valid

    def __reduce__(self):
        raise RuntimeError(
            "Pickling Clump instances is not supported. Please use "
            "Clump.save_as_dataset instead"
        )

    def __getitem__(self, request):
        """Return ``request`` from the underlying data container."""
        return self.data[request]
410
411
def find_clumps(clump, min_val, max_val, d_clump):
    """Recursively find sub-clumps of ``clump`` by stepping a contour threshold.

    Children are identified between ``min_val`` and ``max_val`` with
    ``clump.find_children``.  The lower threshold is then multiplied by
    ``d_clump``, and the search continues on the same clump (if it has
    exactly one child) or on each child (if it has several).  Childless
    children that fail their validators are dropped.  A lone surviving
    child is replaced by its own children.

    Parameters
    ----------
    clump : Clump
        The clump to search; its ``children`` are modified in place.
    min_val : float
        Lower contour threshold for this level.
    max_val : float
        Upper contour threshold, fixed for the whole search.
    d_clump : float
        Multiplicative step applied to ``min_val`` at each level.

    Raises
    ------
    ValueError
        If children are found but ``min_val * d_clump`` does not exceed
        ``min_val`` (for example ``d_clump <= 1`` or ``min_val <= 0``).  In
        that case the search would recurse without ever reaching
        ``max_val``.
    """
    mylog.info("Finding clumps: min: %e, max: %e, step: %f", min_val, max_val, d_clump)
    if min_val >= max_val:
        return
    clump.find_children(min_val, max_val=max_val)

    if not clump.children:
        return

    next_min_val = min_val * d_clump
    # A non-increasing threshold would re-find the same contours forever.
    if not next_min_val > min_val:
        raise ValueError(
            f"Clump threshold does not increase: min_val={min_val}, "
            f"d_clump={d_clump} gives {next_min_val}."
        )

    if len(clump.children) == 1:
        # A single child is the same structure; tighten the threshold on
        # this clump instead of descending.
        find_clumps(clump, next_min_val, max_val, d_clump)
        return

    these_children = []
    mylog.info("Investigating %d children.", len(clump.children))
    for child in clump.children:
        find_clumps(child, next_min_val, max_val, d_clump)
        if len(child.children) > 0:
            these_children.append(child)
        elif child._validate():
            these_children.append(child)
        else:
            mylog.info(
                "Eliminating invalid, childless clump with %d cells.",
                len(child.data[("index", "ones")]),
            )
    if len(these_children) > 1:
        mylog.info(
            "%d of %d children survived.", len(these_children), len(clump.children)
        )
        clump.children = these_children
    elif len(these_children) == 1:
        mylog.info(
            "%d of %d children survived, linking its children to parent.",
            len(these_children),
            len(clump.children),
        )
        clump.children = these_children[0].children
        for child in clump.children:
            child.parent = clump
            child.data.parent = clump.data
    else:
        mylog.info(
            "%d of %d children survived, erasing children.",
            len(these_children),
            len(clump.children),
        )
        clump.children = []
457