1import numpy as np
2
3from yt.fields.derived_field import ValidateParameter, ValidateSpatial
4from yt.units.yt_array import uconcatenate, ucross
5from yt.utilities.lib.misc_utilities import (
6    obtain_position_vector,
7    obtain_relative_velocity_vector,
8)
9from yt.utilities.math_utils import (
10    get_cyl_r,
11    get_cyl_r_component,
12    get_cyl_theta,
13    get_cyl_theta_component,
14    get_cyl_z,
15    get_cyl_z_component,
16    get_sph_phi,
17    get_sph_phi_component,
18    get_sph_r_component,
19    get_sph_theta,
20    get_sph_theta_component,
21    modify_reference_frame,
22)
23
24from .field_functions import get_radius
25from .vector_operations import create_magnitude_field
26
# Field names allowed for SPH data: a handful of generic gas properties,
# followed by the per-element mass fractions and then the per-element
# densities (same element order in both groups).
sph_whitelist_fields = (
    "density",
    "temperature",
    "metallicity",
    "thermal_energy",
    "smoothing_length",
) + tuple(
    f"{element}_{suffix}"
    for suffix in ("fraction", "density")
    for element in ("H", "He", "C", "Ca", "N", "O", "S", "Ne", "Mg", "Si", "Fe")
)
56
57
58def _field_concat(fname):
59    def _AllFields(field, data):
60        v = []
61        for ptype in data.ds.particle_types:
62            data.ds._last_freq = (ptype, None)
63            if ptype == "all" or ptype in data.ds.known_filters:
64                continue
65            v.append(data[ptype, fname].copy())
66        rv = uconcatenate(v, axis=0)
67        return rv
68
69    return _AllFields
70
71
72def _field_concat_slice(fname, axi):
73    def _AllFields(field, data):
74        v = []
75        for ptype in data.ds.particle_types:
76            data.ds._last_freq = (ptype, None)
77            if ptype == "all" or ptype in data.ds.known_filters:
78                continue
79            v.append(data[ptype, fname][:, axi])
80        rv = uconcatenate(v, axis=0)
81        return rv
82
83    return _AllFields
84
85
def particle_deposition_functions(ptype, coord_name, mass_name, registry):
    """Register particle-deposit cell fields and helper particle fields.

    Cell fields, all under the ``"deposit"`` field type:

    * ``{ptype}_count``, ``{ptype}_mass``, ``{ptype}_density`` and
      ``{ptype}_cic`` (mass deposited with ``method="cic"`` divided by the
      cell volume);
    * ``{ptype}_cic_velocity_{x,y,z}`` / ``{ptype}_nn_velocity_{x,y,z}`` and
      ``{ptype}_cic_age`` / ``{ptype}_nn_age``: mass-weighted averages
      deposited with ``method="cic"`` and ``method="sum"`` respectively.

    Particle fields: ``(ptype, "particle_ones")`` and ``(ptype, "mesh_id")``.

    Parameters
    ----------
    ptype : str
        Particle type to deposit.
    coord_name : str
        Name of the ``(ptype, coord_name)`` particle position field.
    mass_name : str
        Name of the ``(ptype, mass_name)`` particle mass field.
    registry
        Field registry receiving the fields; ``registry.ds.unit_system``
        supplies their units.

    Returns
    -------
    list
        The keys this call added to *registry* (in no particular order).
    """
    unit_system = registry.ds.unit_system
    orig = set(registry.keys())
    ptype_dn = ptype.replace("_", " ").title()

    def particle_count(field, data):
        pos = data[ptype, coord_name]
        d = data.deposit(pos, method="count")
        return data.apply_units(d, field.units)

    registry.add_field(
        ("deposit", f"{ptype}_count"),
        sampling_type="cell",
        function=particle_count,
        validators=[ValidateSpatial()],
        units="",
        display_name=r"\mathrm{%s Count}" % ptype_dn,
    )

    def particle_mass(field, data):
        pos = data[ptype, coord_name]
        pmass = data[ptype, mass_name]
        # In-place conversion: values and units are updated together.
        pmass.convert_to_units(field.units)
        d = data.deposit(pos, [pmass], method="sum")
        return data.apply_units(d, field.units)

    registry.add_field(
        ("deposit", f"{ptype}_mass"),
        sampling_type="cell",
        function=particle_mass,
        validators=[ValidateSpatial()],
        display_name=r"\mathrm{%s Mass}" % ptype_dn,
        units=unit_system["mass"],
    )

    def particle_density(field, data):
        pos = data[ptype, coord_name]
        pos.convert_to_units("code_length")
        mass = data[ptype, mass_name]
        mass.convert_to_units("code_mass")
        d = data.deposit(pos, [mass], method="sum")
        # The deposit result is labeled with the units the mass was put in.
        d = data.ds.arr(d, "code_mass")
        d /= data["index", "cell_volume"]
        return d

    registry.add_field(
        ("deposit", f"{ptype}_density"),
        sampling_type="cell",
        function=particle_density,
        validators=[ValidateSpatial()],
        display_name=r"\mathrm{%s Density}" % ptype_dn,
        units=unit_system["density"],
    )

    def particle_cic(field, data):
        pos = data[ptype, coord_name]
        pmass = data[ptype, mass_name]
        d = data.deposit(pos, [pmass], method="cic")
        d = data.apply_units(d, pmass.units)
        d /= data["index", "cell_volume"]
        return d

    registry.add_field(
        ("deposit", f"{ptype}_cic"),
        sampling_type="cell",
        function=particle_cic,
        validators=[ValidateSpatial()],
        display_name=r"\mathrm{%s CIC Density}" % ptype_dn,
        units=unit_system["density"],
    )

    def _get_density_weighted_deposit_field(fname, units, method):
        def _deposit_field(field, data):
            """
            Create a grid field for particle quantities weighted by particle
            mass, deposited with *method*.
            """
            pos = data[ptype, coord_name]
            pmass = data[ptype, mass_name]
            # Convert to the units the result is labeled with below, so values
            # stored in any other (compatible) units are not mislabeled.
            values = data[ptype, fname].in_units(units)
            top = data.deposit(pos, [pmass * values], method=method)
            bottom = data.deposit(pos, [pmass], method=method)
            # Cells without deposited mass get 0 instead of 0/0.
            top[bottom == 0] = 0.0
            bnz = bottom.nonzero()
            top[bnz] /= bottom[bnz]
            return data.ds.arr(top, units=units)

        return _deposit_field

    # (deposit method, label used in the registered field name)
    deposit_methods = (("cic", "cic"), ("sum", "nn"))

    for ax in "xyz":
        for method, name in deposit_methods:
            function = _get_density_weighted_deposit_field(
                f"particle_velocity_{ax}", "code_velocity", method
            )
            registry.add_field(
                ("deposit", f"{ptype}_{name}_velocity_{ax}"),
                sampling_type="cell",
                function=function,
                units=unit_system["velocity"],
                take_log=False,
                validators=[ValidateSpatial(0)],
            )

    for method, name in deposit_methods:
        function = _get_density_weighted_deposit_field("age", "code_time", method)
        registry.add_field(
            ("deposit", f"{ptype}_{name}_age"),
            sampling_type="cell",
            function=function,
            units=unit_system["time"],
            take_log=False,
            validators=[ValidateSpatial(0)],
        )

    # Now some translation functions.

    def particle_ones(field, data):
        v = np.ones(data[ptype, coord_name].shape[0], dtype="float64")
        return data.apply_units(v, field.units)

    registry.add_field(
        (ptype, "particle_ones"),
        sampling_type="particle",
        function=particle_ones,
        units="",
        display_name=r"Particle Count",
    )

    def particle_mesh_ids(field, data):
        pos = data[ptype, coord_name]
        # Initialized to -1; the deposit call fills ``ids`` in place (its
        # return value is not used). This is float64 in name only; it is
        # cast as needed inside the deposit operation.
        ids = np.zeros(pos.shape[0], dtype="float64") - 1
        data.deposit(pos, [ids], method="mesh_id")
        return data.apply_units(ids, "")

    registry.add_field(
        (ptype, "mesh_id"),
        sampling_type="particle",
        function=particle_mesh_ids,
        validators=[ValidateSpatial()],
        units="",
    )

    return list(set(registry.keys()).difference(orig))
232
233
def particle_scalar_functions(ptype, coord_name, vel_name, registry):
    """Register per-axis scalar velocity and position fields for *ptype*.

    For each axis letter in ``"xyz"`` this adds
    ``(ptype, "particle_velocity_<ax>")`` (``code_velocity``) and
    ``(ptype, "particle_position_<ax>")`` (``code_length``), each returning
    one column of the 3-component ``(ptype, vel_name)`` /
    ``(ptype, coord_name)`` field.
    """
    # Axis index and particle type are bound per call of the factory, so each
    # generated closure keeps its own values instead of the loop's last ones.

    def _axis_accessors(index, pt):
        def _particle_velocity(field, data):
            return data[pt, vel_name][:, index]

        def _particle_position(field, data):
            return data[pt, coord_name][:, index]

        return _particle_velocity, _particle_position

    for index, axis in enumerate("xyz"):
        vel_func, pos_func = _axis_accessors(index, ptype)
        for quantity, func, units in (
            ("velocity", vel_func, "code_velocity"),
            ("position", pos_func, "code_length"),
        ):
            registry.add_field(
                (ptype, f"particle_{quantity}_{axis}"),
                sampling_type="particle",
                function=func,
                units=units,
            )
265
266
def particle_vector_functions(ptype, coord_names, vel_names, registry):
    """Register 3-component position and velocity fields for *ptype*.

    ``(ptype, "particle_position")`` stacks the fields named in *coord_names*
    and ``(ptype, "particle_velocity")`` stacks those in *vel_names*, column
    by column, after converting each to the registered field's units.
    """
    velocity_units = registry.ds.unit_system["velocity"]

    def _stacker(pt, components):
        def particle_vectors(field, data):
            columns = [data[pt, comp].in_units(field.units) for comp in components]
            stacked = np.column_stack(columns)
            return data.ds.arr(stacked, columns[0].units)

        return particle_vectors

    for vector_name, components, units in (
        ("particle_position", coord_names, "code_length"),
        ("particle_velocity", vel_names, velocity_units),
    ):
        registry.add_field(
            (ptype, vector_name),
            sampling_type="particle",
            function=_stacker(ptype, components),
            units=units,
        )
293
294
def get_angular_momentum_components(ptype, data, spos, svel):
    """Collect the inputs used by the angular-momentum fields of *ptype*.

    Parameters
    ----------
    ptype : str
        Particle type.
    data
        Data object providing fields, field parameters and ``data.ds``.
    spos, svel : str
        ``%``-templates (one ``%s`` for the axis letter) naming the per-axis
        position fields and, prefixed with ``"relative_"``, velocity fields.

    Returns
    -------
    tuple
        ``(pos, vel, normal)``: positions and relative velocities stacked as
        one row per particle, and the ``normal`` field parameter, defaulting
        to ``[0, 0, 1]`` (the simulation z axis) when it is not set.
    """
    ds = data.ds
    if not data.has_field_parameter("normal"):
        # default to simulation axis
        normal = ds.arr([0.0, 0.0, 1.0], "code_length")
    else:
        normal = data.get_field_parameter("normal")
    axes = "xyz"
    pos = ds.arr([data[ptype, spos % a] for a in axes]).T
    vel = ds.arr([data[ptype, "relative_" + svel % a] for a in axes]).T
    return pos, vel, normal
305
306
def standard_particle_fields(
    registry, ptype, spos="particle_position_%s", svel="particle_velocity_%s"
):
    """Register the standard derived particle fields for *ptype*.

    Adds velocity magnitude, specific and total angular momentum (vector,
    components and magnitude), particle radius, positions/velocities relative
    to the ``center``/``bulk_velocity`` field parameters, their spherical and
    cylindrical decompositions about the ``normal`` field parameter, and a
    set of aliases (several of them deprecated).

    Parameters
    ----------
    registry
        Field registry receiving the fields; ``registry.ds.unit_system``
        supplies their units.
    ptype : str
        Particle type the fields are registered for.
    spos, svel : str
        ``%``-templates (one ``%s`` for the axis letter) naming the per-axis
        particle position and velocity fields.
    """
    unit_system = registry.ds.unit_system

    def _deprecated_alias(old_name, new_name):
        # Keep ``old_name`` as a deprecated alias of ``new_name``.
        registry.alias(
            (ptype, old_name),
            (ptype, new_name),
            deprecate=("4.0.0", "4.1.0"),
        )

    def _particle_velocity_magnitude(field, data):
        """Magnitude |v| built from the relative velocity component fields."""
        return np.sqrt(
            data[ptype, f"relative_{svel % 'x'}"] ** 2
            + data[ptype, f"relative_{svel % 'y'}"] ** 2
            + data[ptype, f"relative_{svel % 'z'}"] ** 2
        )

    registry.add_field(
        (ptype, "particle_velocity_magnitude"),
        sampling_type="particle",
        function=_particle_velocity_magnitude,
        take_log=False,
        units=unit_system["velocity"],
    )

    def _particle_specific_angular_momentum(field, data):
        """Calculate the specific angular momentum of each particle.

        Returns a vector per particle: the cross product of position and
        velocity after both pass through ``modify_reference_frame`` together
        with the ``center`` and ``normal`` parameters.
        """
        center = data.get_field_parameter("center")
        pos, vel, normal = get_angular_momentum_components(ptype, data, spos, svel)
        L, r_vec, v_vec = modify_reference_frame(center, normal, P=pos, V=vel)
        # adding in the unit registry allows us to have a reference to the
        # dataset and thus we will always get the correct units after applying
        # the cross product.
        return ucross(r_vec, v_vec, registry=data.ds.unit_registry)

    registry.add_field(
        (ptype, "particle_specific_angular_momentum"),
        sampling_type="particle",
        function=_particle_specific_angular_momentum,
        units=unit_system["specific_angular_momentum"],
        validators=[ValidateParameter("center")],
    )

    def _get_spec_ang_mom_comp(axi, ax, _ptype):
        def _particle_specific_angular_momentum_component(field, data):
            return data[_ptype, "particle_specific_angular_momentum"][:, axi]

        def _particle_angular_momentum_component(field, data):
            return (
                data[_ptype, "particle_mass"]
                * data[_ptype, f"particle_specific_angular_momentum_{ax}"]
            )

        return (
            _particle_specific_angular_momentum_component,
            _particle_angular_momentum_component,
        )

    for axi, ax in enumerate("xyz"):
        f, v = _get_spec_ang_mom_comp(axi, ax, ptype)
        registry.add_field(
            (ptype, f"particle_specific_angular_momentum_{ax}"),
            sampling_type="particle",
            function=f,
            units=unit_system["specific_angular_momentum"],
            validators=[ValidateParameter("center")],
        )
        registry.add_field(
            (ptype, f"particle_angular_momentum_{ax}"),
            sampling_type="particle",
            function=v,
            units=unit_system["angular_momentum"],
            validators=[ValidateParameter("center")],
        )

    def _particle_angular_momentum(field, data):
        # Transpose so the per-particle mass broadcasts across the 3 components.
        am = (
            data[ptype, "particle_mass"]
            * data[ptype, "particle_specific_angular_momentum"].T
        )
        return am.T

    registry.add_field(
        (ptype, "particle_angular_momentum"),
        sampling_type="particle",
        function=_particle_angular_momentum,
        units=unit_system["angular_momentum"],
        validators=[ValidateParameter("center")],
    )

    create_magnitude_field(
        registry,
        "particle_angular_momentum",
        unit_system["angular_momentum"],
        sampling_type="particle",
        ftype=ptype,
    )

    def _particle_radius(field, data):
        """The spherical radius component of the particle positions

        Relative to the coordinate system defined by the *normal* vector,
        and *center* field parameters.
        """
        return get_radius(data, "particle_position_", field.name[0])

    registry.add_field(
        (ptype, "particle_radius"),
        sampling_type="particle",
        function=_particle_radius,
        units=unit_system["length"],
        validators=[ValidateParameter("center")],
    )

    def _relative_particle_position(field, data):
        """The cartesian particle positions in a rotated reference frame

        Relative to the coordinate system defined by the *center* field
        parameter.

        Note that the orientation of the x and y axes is arbitrary.
        """
        field_names = [(ptype, spos % ax) for ax in "xyz"]
        return obtain_position_vector(data, field_names=field_names).T

    registry.add_field(
        (ptype, "relative_particle_position"),
        sampling_type="particle",
        function=_relative_particle_position,
        units=unit_system["length"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    def _relative_particle_velocity(field, data):
        """The vector particle velocities in an arbitrary coordinate system

        Relative to the coordinate system defined by the *bulk_velocity*
        vector field parameter.

        Note that the orientation of the x and y axes is arbitrary.
        """
        field_names = [(ptype, svel % ax) for ax in "xyz"]
        return obtain_relative_velocity_vector(data, field_names=field_names).T

    registry.add_field(
        (ptype, "relative_particle_velocity"),
        sampling_type="particle",
        function=_relative_particle_velocity,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    def _get_coord_funcs_relative(axi, _ptype):
        def _particle_pos_rel(field, data):
            return data[_ptype, "relative_particle_position"][:, axi]

        def _particle_vel_rel(field, data):
            return data[_ptype, "relative_particle_velocity"][:, axi]

        return _particle_vel_rel, _particle_pos_rel

    for axi, ax in enumerate("xyz"):
        v, p = _get_coord_funcs_relative(axi, ptype)
        registry.add_field(
            (ptype, f"particle_velocity_relative_{ax}"),
            sampling_type="particle",
            function=v,
            units="code_velocity",
        )
        registry.add_field(
            (ptype, f"particle_position_relative_{ax}"),
            sampling_type="particle",
            function=p,
            units="code_length",
        )
        registry.add_field(
            (ptype, f"relative_particle_velocity_{ax}"),
            sampling_type="particle",
            function=v,
            units="code_velocity",
        )
        registry.add_field(
            (ptype, f"relative_particle_position_{ax}"),
            sampling_type="particle",
            function=p,
            units="code_length",
        )

    # this is just particle radius but we add it with an alias for the sake of
    # consistent naming
    registry.add_field(
        (ptype, "particle_position_spherical_radius"),
        sampling_type="particle",
        function=_particle_radius,
        units=unit_system["length"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_spherical_position_radius", "particle_position_spherical_radius"
    )

    def _particle_position_spherical_theta(field, data):
        """The spherical theta coordinate of the particle positions.

        Relative to the coordinate system defined by the *normal* vector
        and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        return data.ds.arr(get_sph_theta(pos, normal), "")

    registry.add_field(
        (ptype, "particle_position_spherical_theta"),
        sampling_type="particle",
        function=_particle_position_spherical_theta,
        units="",
        validators=[ValidateParameter("center"), ValidateParameter("normal")],
    )

    _deprecated_alias(
        "particle_spherical_position_theta", "particle_position_spherical_theta"
    )

    def _particle_position_spherical_phi(field, data):
        """The spherical phi component of the particle positions

        Relative to the coordinate system defined by the *normal* vector
        and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        return data.ds.arr(get_sph_phi(pos, normal), "")

    registry.add_field(
        (ptype, "particle_position_spherical_phi"),
        sampling_type="particle",
        function=_particle_position_spherical_phi,
        units="",
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_spherical_position_phi", "particle_position_spherical_phi"
    )

    def _particle_velocity_spherical_radius(field, data):
        """The spherical radius component of the particle velocities in an
         arbitrary coordinate system

        Relative to the coordinate system defined by the *normal* vector,
        *bulk_velocity* vector and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        vel = data[(ptype, "relative_particle_velocity")].T
        theta = get_sph_theta(pos, normal)
        phi = get_sph_phi(pos, normal)
        sphr = get_sph_r_component(vel, theta, phi, normal)
        return sphr

    registry.add_field(
        (ptype, "particle_velocity_spherical_radius"),
        sampling_type="particle",
        function=_particle_velocity_spherical_radius,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_spherical_velocity_radius", "particle_velocity_spherical_radius"
    )

    # Non-deprecated alias.
    registry.alias(
        (ptype, "particle_radial_velocity"),
        (ptype, "particle_velocity_spherical_radius"),
    )

    def _particle_velocity_spherical_theta(field, data):
        """The spherical theta component of the particle velocities in an
         arbitrary coordinate system

        Relative to the coordinate system defined by the *normal* vector,
        *bulk_velocity* vector and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        vel = data[(ptype, "relative_particle_velocity")].T
        theta = get_sph_theta(pos, normal)
        phi = get_sph_phi(pos, normal)
        spht = get_sph_theta_component(vel, theta, phi, normal)
        return spht

    registry.add_field(
        (ptype, "particle_velocity_spherical_theta"),
        sampling_type="particle",
        function=_particle_velocity_spherical_theta,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_spherical_velocity_theta", "particle_velocity_spherical_theta"
    )

    def _particle_velocity_spherical_phi(field, data):
        """The spherical phi component of the particle velocities

        Relative to the coordinate system defined by the *normal* vector,
        *bulk_velocity* vector and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        vel = data[(ptype, "relative_particle_velocity")].T
        phi = get_sph_phi(pos, normal)
        sphp = get_sph_phi_component(vel, phi, normal)
        return sphp

    registry.add_field(
        (ptype, "particle_velocity_spherical_phi"),
        sampling_type="particle",
        function=_particle_velocity_spherical_phi,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_spherical_velocity_phi", "particle_velocity_spherical_phi"
    )

    def _particle_position_cylindrical_radius(field, data):
        """The cylindrical radius component of the particle positions

        Relative to the coordinate system defined by the *normal* vector
        and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        # ``.T`` is a view sharing memory with the cached field: convert a copy,
        # because an in-place conversion would rescale the shared buffer while
        # the cached array keeps its old units.
        pos = data[(ptype, "relative_particle_position")].T.in_units("code_length")
        return data.ds.arr(get_cyl_r(pos, normal), "code_length")

    registry.add_field(
        (ptype, "particle_position_cylindrical_radius"),
        sampling_type="particle",
        function=_particle_position_cylindrical_radius,
        units=unit_system["length"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    def _particle_position_cylindrical_theta(field, data):
        """The cylindrical theta component of the particle positions

        Relative to the coordinate system defined by the *normal* vector
        and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        return data.ds.arr(get_cyl_theta(pos, normal), "")

    registry.add_field(
        (ptype, "particle_position_cylindrical_theta"),
        sampling_type="particle",
        function=_particle_position_cylindrical_theta,
        units="",
        validators=[ValidateParameter("center"), ValidateParameter("normal")],
    )

    def _particle_position_cylindrical_z(field, data):
        """The cylindrical z component of the particle positions

        Relative to the coordinate system defined by the *normal* vector
        and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        # Out-of-place conversion of the transposed view (see cylindrical radius).
        pos = data[(ptype, "relative_particle_position")].T.in_units("code_length")
        return data.ds.arr(get_cyl_z(pos, normal), "code_length")

    registry.add_field(
        (ptype, "particle_position_cylindrical_z"),
        sampling_type="particle",
        function=_particle_position_cylindrical_z,
        units=unit_system["length"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    def _particle_velocity_cylindrical_radius(field, data):
        """The cylindrical radius component of the particle velocities

        Relative to the coordinate system defined by the *normal* vector,
        *bulk_velocity* vector and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        vel = data[(ptype, "relative_particle_velocity")].T
        theta = get_cyl_theta(pos, normal)
        cylr = get_cyl_r_component(vel, theta, normal)
        return cylr

    registry.add_field(
        (ptype, "particle_velocity_cylindrical_radius"),
        sampling_type="particle",
        function=_particle_velocity_cylindrical_radius,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    def _particle_velocity_cylindrical_theta(field, data):
        """The cylindrical theta component of the particle velocities

        Relative to the coordinate system defined by the *normal* vector,
        *bulk_velocity* vector and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        pos = data[(ptype, "relative_particle_position")].T
        vel = data[(ptype, "relative_particle_velocity")].T
        theta = get_cyl_theta(pos, normal)
        cylt = get_cyl_theta_component(vel, theta, normal)
        return cylt

    registry.add_field(
        (ptype, "particle_velocity_cylindrical_theta"),
        sampling_type="particle",
        function=_particle_velocity_cylindrical_theta,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_cylindrical_velocity_theta", "particle_velocity_cylindrical_theta"
    )

    def _particle_velocity_cylindrical_z(field, data):
        """The cylindrical z component of the particle velocities

        Relative to the coordinate system defined by the *normal* vector,
        *bulk_velocity* vector and *center* field parameters.
        """
        normal = data.get_field_parameter("normal")
        vel = data[(ptype, "relative_particle_velocity")].T
        cylz = get_cyl_z_component(vel, normal)
        return cylz

    registry.add_field(
        (ptype, "particle_velocity_cylindrical_z"),
        sampling_type="particle",
        function=_particle_velocity_cylindrical_z,
        units=unit_system["velocity"],
        validators=[ValidateParameter("normal"), ValidateParameter("center")],
    )

    _deprecated_alias(
        "particle_cylindrical_velocity_z", "particle_velocity_cylindrical_z"
    )
775
776
def add_particle_average(registry, ptype, field_name, weight=None, density=True):
    """Add a deposited, weighted average of a particle field.

    Registers ``("deposit", f"{ptype}_avg_{field_name}")``. For each cell, the
    field is the deposited sum of ``field * weight`` divided by the deposited
    sum of the weights, optionally divided further by the cell volume. NaN
    results (e.g. 0/0 in cells with no deposited weight) are set to 0.

    Parameters
    ----------
    registry
        Field registry; must already contain ``(ptype, field_name)``.
    ptype : str
        Particle type to average over.
    field_name : str
        Particle field to average.
    weight : str or tuple, optional
        Weight field. A bare name is looked up under *ptype*; a
        ``(ftype, fname)`` tuple is used as given. Defaults to
        ``(ptype, "particle_mass")``.
    density : bool, optional
        If True (default), divide the average by ``("index", "cell_volume")``.
        NOTE(review): the new field is still registered with the units of
        ``(ptype, field_name)`` in that case; confirm the intended units.

    Returns
    -------
    tuple
        Key of the registered field.
    """
    if weight is None:
        weight = (ptype, "particle_mass")
    # Normalize to a full field key: ``data[ptype, weight]`` with a tuple
    # weight (including the default above) would build a nested, invalid key.
    weight_key = weight if isinstance(weight, tuple) else (ptype, weight)
    field_units = registry[ptype, field_name].units

    def _pfunc_avg(field, data):
        pos = data[ptype, "particle_position"]
        f = data[ptype, field_name]
        wf = data[weight_key]
        # Multiply out of place: ``f *= wf`` would overwrite the cached field.
        v = data.deposit(pos, [f * wf], method="sum")
        w = data.deposit(pos, [wf], method="sum")
        # Cells without weight produce 0/0; the NaNs are zeroed below.
        with np.errstate(divide="ignore", invalid="ignore"):
            v /= w
        if density:
            v /= data["index", "cell_volume"]
        v[np.isnan(v)] = 0.0
        return v

    fn = ("deposit", f"{ptype}_avg_{field_name}")
    registry.add_field(
        fn,
        sampling_type="cell",
        function=_pfunc_avg,
        validators=[ValidateSpatial(0)],
        units=field_units,
    )
    return fn
804
805
def add_volume_weighted_smoothed_field(
    ptype,
    coord_name,
    mass_name,
    smoothing_length_name,
    density_name,
    smoothed_field,
    registry,
    nneighbors=64,
    kernel_name="cubic",
):
    """Deprecated no-op retained for backward compatibility.

    Only issues a deprecation warning (since 4.0.0, removal planned for
    4.1.0). All arguments are ignored, no field is registered, and the
    function returns None.
    """
    from yt._maintenance.deprecation import issue_deprecation_warning

    reason = (
        "This function is deprecated. "
        "Since yt-4.0, it's no longer necessary to add a field specifically for "
        "smoothing, because the global octree is removed. The old behavior of "
        "interpolating onto a grid structure can be recovered through data objects "
        "like ds.arbitrary_grid, ds.covering_grid, and most closely ds.octree. The "
        "visualization machinery now treats SPH fields properly by smoothing onto "
        "pixel locations. See this page to learn more: "
        "https://yt-project.org/doc/yt4differences.html"
    )
    issue_deprecation_warning(reason, since="4.0.0", removal="4.1.0")
831
832
def add_nearest_neighbor_field(ptype, coord_name, registry, nneighbors=64):
    """Register the distance from each particle to its *nneighbors*-th neighbor.

    Adds ``(ptype, f"nearest_neighbor_distance_{nneighbors}")`` in
    ``code_length`` and returns a one-element list holding that key.
    """
    fname = (ptype, f"nearest_neighbor_distance_{nneighbors}")

    def _nth_neighbor(field, data):
        positions = data[ptype, coord_name]
        # Converted in place so the search operates in code units.
        positions.convert_to_units("code_length")
        # Zero-filled buffer carrying the positions' units; the particle
        # operation writes the distances into it.
        dist = 0.0 * positions[:, 0]
        data.particle_operation(
            positions, [dist], method="nth_neighbor", nneighbors=nneighbors
        )
        return dist

    registry.add_field(
        fname,
        sampling_type="particle",
        function=_nth_neighbor,
        validators=[ValidateSpatial(0)],
        units="code_length",
    )
    return [fname]
854
855
def add_nearest_neighbor_value_field(ptype, coord_name, sampled_field, registry):
    """
    Add a nearest-neighbor field, where values on the mesh are assigned
    based on the nearest particle value found.  This is useful, for instance,
    with Voronoi tessellations.

    Parameters
    ----------
    ptype : str
        Particle type providing the values.
    coord_name : str
        Name of the ``(ptype, coord_name)`` particle position field.
    sampled_field : str
        Particle field whose nearest value is assigned to each cell; must
        already be registered under *ptype*.
    registry
        Field registry receiving the new field.

    Returns
    -------
    list
        ``[("deposit", f"{ptype}_nearest_{sampled_field}")]``.
    """
    field_name = ("deposit", f"{ptype}_nearest_{sampled_field}")
    field_units = registry[ptype, sampled_field].units

    def _nearest_value(field, data):
        pos = data[ptype, coord_name]
        # In-place conversion; its return value is deliberately not used.
        pos.convert_to_units("code_length")
        # Sample in exactly the units the result is labeled with below, so
        # apply_units does not mislabel values held in other units.
        value = data[ptype, sampled_field].in_units(field_units)
        rv = data.smooth(
            pos, [value], method="nearest", create_octree=True, nneighbors=1
        )
        return data.apply_units(rv, field_units)

    registry.add_field(
        field_name,
        sampling_type="cell",
        function=_nearest_value,
        validators=[ValidateSpatial(0)],
        units=field_units,
    )
    return [field_name]
884
885
def add_union_field(registry, ptype, field_name, units):
    """
    Create a field that is the concatenation of multiple particle types.
    This allows us to create fields for particle unions using alias names.

    The registered ``(ptype, field_name)`` field joins
    ``data[raw_type, field_name]`` for every type in
    ``data.ds.particle_types_raw``.
    """

    def _cat_field(field, data):
        raw_types = data.ds.particle_types_raw
        pieces = [data[raw_type, field_name] for raw_type in raw_types]
        return uconcatenate(pieces)

    registry.add_field(
        (ptype, field_name),
        sampling_type="particle",
        function=_cat_field,
        units=units,
    )
900