1import numpy as np
2
3from yt.funcs import is_sequence, just_one
4from yt.geometry.geometry_handler import is_curvilinear
5from yt.utilities.lib.misc_utilities import obtain_relative_velocity_vector
6from yt.utilities.math_utils import (
7    get_cyl_r_component,
8    get_cyl_theta_component,
9    get_cyl_z_component,
10    get_sph_phi_component,
11    get_sph_r_component,
12    get_sph_theta_component,
13)
14
15from .derived_field import NeedsParameter, ValidateParameter, ValidateSpatial
16
17
def get_bulk(data, basename, unit):
    """Return the bulk offset vector associated with *basename*.

    If *data* carries a ``bulk_{basename}`` field parameter, that value is
    returned unchanged; otherwise a three-component zero vector is built as
    ``[0, 0, 0] * unit``.
    """
    parameter_name = f"bulk_{basename}"
    if not data.has_field_parameter(parameter_name):
        return [0, 0, 0] * unit
    return data.get_field_parameter(parameter_name)
24
25
def create_magnitude_field(
    registry,
    basename,
    field_units,
    ftype="gas",
    slice_info=None,
    validators=None,
    sampling_type=None,
):
    """Register ``(ftype, f"{basename}_magnitude")`` on *registry*.

    The field is the Euclidean norm of the first ``ds.dimensionality``
    components ``(ftype, f"{basename}_{ax}")`` taken in coordinate axis
    order. When the ``bulk_{basename}`` field parameter is present, the
    ``relative_`` variants of the components are used instead.

    ``slice_info`` is accepted for signature compatibility and is unused.
    ``sampling_type`` defaults to ``"local"``.
    """
    components = [
        (ftype, f"{basename}_{axis}") for axis in registry.ds.coordinates.axis_order
    ]
    bulk_parameter = f"bulk_{basename}"

    def _component_name(data, idim):
        # Switch to the bulk-subtracted component when a bulk value is set.
        name = components[idim]
        if data.has_field_parameter(bulk_parameter):
            return (name[0], f"relative_{name[1]}")
        return name

    def _magnitude(field, data):
        total = data[_component_name(data, 0)] ** 2
        for idim in range(1, registry.ds.dimensionality):
            total += data[_component_name(data, idim)] ** 2
        return np.sqrt(total)

    registry.add_field(
        (ftype, f"{basename}_magnitude"),
        sampling_type="local" if sampling_type is None else sampling_type,
        function=_magnitude,
        units=field_units,
        validators=validators,
    )
63
64
def create_relative_field(
    registry, basename, field_units, ftype="gas", slice_info=None, validators=None
):
    """Register ``(ftype, f"relative_{basename}_{ax}")`` for every axis.

    Each relative component is the original component minus the matching
    entry of the bulk vector returned by ``get_bulk`` (zero when no
    ``bulk_{basename}`` field parameter is set). ``slice_info`` is accepted
    for signature compatibility and is unused.
    """
    axis_order = registry.ds.coordinates.axis_order
    source_fields = [(ftype, f"{basename}_{axis}") for axis in axis_order]

    def _make_relative(axis):
        index = axis_order.index(axis)

        def _relative_vector(field, data):
            values = data[source_fields[index]]
            offset = get_bulk(data, basename, values.unit_quantity)
            return values - offset[index]

        return _relative_vector

    for axis in axis_order:
        registry.add_field(
            (ftype, f"relative_{basename}_{axis}"),
            sampling_type="local",
            function=_make_relative(axis),
            units=field_units,
            validators=validators,
        )
90
91
def create_los_field(registry, basename, field_units, ftype="gas", slice_info=None):
    """Register the line-of-sight component ``(ftype, f"{basename}_los")``.

    The ``axis`` field parameter selects the line of sight. It is either an
    integer axis index in ``{0, 1, 2}``, which returns that component, or a
    three-element sequence. A sequence is normalized to unit length and
    dotted with the components. When the ``bulk_{basename}`` field parameter
    is present, the ``relative_`` components are projected instead.

    ``slice_info`` is accepted for signature compatibility and is unused.

    Raises
    ------
    NeedsParameter
        If ``axis`` is neither a sequence nor one of 0, 1, 2.
    """
    axis_order = registry.ds.coordinates.axis_order

    validators = [
        ValidateParameter(f"bulk_{basename}"),
        ValidateParameter("axis", {"axis": [0, 1, 2]}),
    ]

    field_comps = [(ftype, f"{basename}_{ax}") for ax in axis_order]

    def _los_field(field, data):
        if data.has_field_parameter(f"bulk_{basename}"):
            fns = [(fc[0], f"relative_{fc[1]}") for fc in field_comps]
        else:
            fns = field_comps
        ax = data.get_field_parameter("axis")
        if is_sequence(ax):
            # Normalize into a *new* float array. An in-place ``/=`` would
            # rescale the array stored as the caller's field parameter, and
            # raises a casting error for integer arrays like [0, 0, 1].
            ax = np.asarray(ax, dtype="float64")
            ax = ax / np.sqrt(np.dot(ax, ax))
            ret = data[fns[0]] * ax[0] + data[fns[1]] * ax[1] + data[fns[2]] * ax[2]
        elif ax in [0, 1, 2]:
            ret = data[fns[ax]]
        else:
            raise NeedsParameter(["axis"])
        return ret

    registry.add_field(
        (ftype, f"{basename}_los"),
        sampling_type="local",
        function=_los_field,
        units=field_units,
        validators=validators,
    )
125
126
def create_squared_field(
    registry, basename, field_units, ftype="gas", slice_info=None, validators=None
):
    """Register ``(ftype, f"{basename}_squared")`` on *registry*.

    The field is the sum of squares of the first ``ds.dimensionality``
    components ``(ftype, f"{basename}_{ax}")`` in coordinate axis order. When
    the ``bulk_{basename}`` field parameter is present, the ``relative_``
    variant is used for every component, matching the magnitude field.

    ``slice_info`` is accepted for signature compatibility and is unused.
    """
    axis_order = registry.ds.coordinates.axis_order

    field_components = [(ftype, f"{basename}_{ax}") for ax in axis_order]

    def _component(data, idim):
        # Previously only the first component was switched to its relative
        # form. This mixed bulk-subtracted and absolute values in one sum.
        fn = field_components[idim]
        if data.has_field_parameter(f"bulk_{basename}"):
            fn = (fn[0], f"relative_{fn[1]}")
        return data[fn]

    def _squared(field, data):
        first = _component(data, 0)
        squared = first * first
        for idim in range(1, registry.ds.dimensionality):
            comp = _component(data, idim)
            squared += comp * comp
        return squared

    registry.add_field(
        (ftype, f"{basename}_squared"),
        sampling_type="local",
        function=_squared,
        units=field_units,
        validators=validators,
    )
152
153
def create_vector_fields(registry, basename, field_units, ftype="gas", slice_info=None):
    """Register the family of derived fields built on a vector field.

    The vector is given by components ``(ftype, f"{basename}_{ax}")``. This
    always registers the relative (bulk-subtracted) and magnitude fields.

    For non-curvilinear geometries it also registers:

    * spherical and cylindrical projections
    * the line-of-sight and cutting-plane components
    * divergence, and radial/tangential fields

    For curvilinear geometries it instead registers the Cartesian components
    ``{basename}_cartesian_{x,y,z}``, as far as the dimensionality allows.

    Parameters
    ----------
    registry : field registry
        Registry the fields are added to. ``registry.ds`` supplies the
        geometry, dimensionality and unit registry.
    basename : str
        Base name of the vector field, e.g. ``"velocity"``.
    field_units : str or Unit
        Units of the vector components.
    ftype : str
        Field type of the components (default ``"gas"``).
    slice_info : tuple, optional
        ``(sl_left, sl_right, div_fac)`` for the divergence stencil. Defaults
        to a centered difference.
    """
    from yt.units.unit_object import Unit

    # slice_info would be the left, the right, and the factor.
    # For example, with the old Enzo-ZEUS fields, this would be:
    # slice(None, -2, None)
    # slice(1, -1, None)
    # 1.0
    # Otherwise, we default to a centered difference.
    if slice_info is None:
        sl_left = slice(None, -2, None)
        sl_right = slice(2, None, None)
        div_fac = 2.0
    else:
        sl_left, sl_right, div_fac = slice_info

    xn, yn, zn = ((ftype, f"{basename}_{ax}") for ax in "xyz")

    # Missing dimensions are padded with the all-zero index field.
    # NOTE(review): the original author questioned whether this is safe.
    # Code below that builds "relative_" names from these tuples would look
    # up ("index", "relative_zeros") in < 3D. Confirm.
    if registry.ds.dimensionality < 3:
        zn = ("index", "zeros")
    if registry.ds.dimensionality < 2:
        yn = ("index", "zeros")

    create_relative_field(
        registry,
        basename,
        field_units,
        ftype=ftype,
        slice_info=slice_info,
        validators=[ValidateParameter(f"bulk_{basename}")],
    )

    create_magnitude_field(
        registry,
        basename,
        field_units,
        ftype=ftype,
        slice_info=slice_info,
        validators=[ValidateParameter(f"bulk_{basename}")],
    )

    if not is_curvilinear(registry.ds.geometry):

        # The following fields are invalid for curvilinear geometries
        def _spherical_radius_component(field, data):
            """The spherical radius component of the vector field

            Relative to the coordinate system defined by the *normal* vector,
            *center*, and *bulk_* field parameters.
            """
            normal = data.get_field_parameter("normal")
            vectors = obtain_relative_velocity_vector(
                data, (xn, yn, zn), f"bulk_{basename}"
            )
            theta = data["index", "spherical_theta"]
            phi = data["index", "spherical_phi"]
            rv = get_sph_r_component(vectors, theta, phi, normal)
            # Now, anywhere that radius is in fact zero, we want to zero out our
            # return values.
            rv[np.isnan(theta)] = 0.0
            return rv

        registry.add_field(
            (ftype, f"{basename}_spherical_radius"),
            sampling_type="local",
            function=_spherical_radius_component,
            units=field_units,
            validators=[
                ValidateParameter("normal"),
                ValidateParameter("center"),
                ValidateParameter(f"bulk_{basename}"),
            ],
        )
        create_los_field(
            registry, basename, field_units, ftype=ftype, slice_info=slice_info
        )

        def _radial(field, data):
            return data[ftype, f"{basename}_spherical_radius"]

        def _radial_absolute(field, data):
            return np.abs(data[ftype, f"{basename}_spherical_radius"])

        def _tangential(field, data):
            return np.sqrt(
                data[ftype, f"{basename}_spherical_theta"] ** 2.0
                + data[ftype, f"{basename}_spherical_phi"] ** 2.0
            )

        registry.add_field(
            (ftype, f"radial_{basename}"),
            sampling_type="local",
            function=_radial,
            units=field_units,
            validators=[ValidateParameter("normal"), ValidateParameter("center")],
        )

        registry.add_field(
            (ftype, f"radial_{basename}_absolute"),
            sampling_type="local",
            function=_radial_absolute,
            units=field_units,
        )

        registry.add_field(
            (ftype, f"tangential_{basename}"),
            sampling_type="local",
            function=_tangential,
            units=field_units,
        )

        def _spherical_theta_component(field, data):
            """The spherical theta component of the vector field

            Relative to the coordinate system defined by the *normal* vector,
            *center*, and *bulk_* field parameters.
            """
            normal = data.get_field_parameter("normal")
            vectors = obtain_relative_velocity_vector(
                data, (xn, yn, zn), f"bulk_{basename}"
            )
            theta = data["index", "spherical_theta"]
            phi = data["index", "spherical_phi"]
            return get_sph_theta_component(vectors, theta, phi, normal)

        registry.add_field(
            (ftype, f"{basename}_spherical_theta"),
            sampling_type="local",
            function=_spherical_theta_component,
            units=field_units,
            validators=[
                ValidateParameter("normal"),
                ValidateParameter("center"),
                ValidateParameter(f"bulk_{basename}"),
            ],
        )

        def _spherical_phi_component(field, data):
            """The spherical phi component of the vector field

            Relative to the coordinate system defined by the *normal* vector,
            *center*, and *bulk_* field parameters.
            """
            normal = data.get_field_parameter("normal")
            vectors = obtain_relative_velocity_vector(
                data, (xn, yn, zn), f"bulk_{basename}"
            )
            phi = data["index", "spherical_phi"]
            return get_sph_phi_component(vectors, phi, normal)

        registry.add_field(
            (ftype, f"{basename}_spherical_phi"),
            sampling_type="local",
            function=_spherical_phi_component,
            units=field_units,
            validators=[
                ValidateParameter("normal"),
                ValidateParameter("center"),
                ValidateParameter(f"bulk_{basename}"),
            ],
        )

        def _cp_vectors(ax):
            # Project the relative vector onto the cutting-plane basis vector
            # stored in the ``cp_{ax}_vec`` field parameter.
            def _cp_val(field, data):
                vec = data.get_field_parameter(f"cp_{ax}_vec")
                tr = data[xn[0], f"relative_{xn[1]}"] * vec.d[0]
                tr += data[yn[0], f"relative_{yn[1]}"] * vec.d[1]
                tr += data[zn[0], f"relative_{zn[1]}"] * vec.d[2]
                return tr

            return _cp_val

        for ax in "xyz":
            registry.add_field(
                (ftype, f"cutting_plane_{basename}_{ax}"),
                sampling_type="local",
                function=_cp_vectors(ax),
                units=field_units,
            )

        def _divergence(field, data):
            # Finite-difference divergence on the interior cells, using the
            # sl_left / sl_right / div_fac stencil chosen above. The one-cell
            # border of the returned array is left at zero.
            ds = div_fac * just_one(data["index", "dx"])
            f = data[xn[0], f"relative_{xn[1]}"][sl_right, 1:-1, 1:-1] / ds
            f -= data[xn[0], f"relative_{xn[1]}"][sl_left, 1:-1, 1:-1] / ds
            ds = div_fac * just_one(data["index", "dy"])
            f += data[yn[0], f"relative_{yn[1]}"][1:-1, sl_right, 1:-1] / ds
            f -= data[yn[0], f"relative_{yn[1]}"][1:-1, sl_left, 1:-1] / ds
            ds = div_fac * just_one(data["index", "dz"])
            f += data[zn[0], f"relative_{zn[1]}"][1:-1, 1:-1, sl_right] / ds
            f -= data[zn[0], f"relative_{zn[1]}"][1:-1, 1:-1, sl_left] / ds
            new_field = data.ds.arr(np.zeros(data[xn].shape, dtype=np.float64), f.units)
            new_field[1:-1, 1:-1, 1:-1] = f
            return new_field

        def _divergence_abs(field, data):
            return np.abs(data[ftype, f"{basename}_divergence"])

        field_units = Unit(field_units, registry=registry.ds.unit_registry)
        div_units = field_units / registry.ds.unit_system["length"]

        registry.add_field(
            (ftype, f"{basename}_divergence"),
            sampling_type="local",
            function=_divergence,
            units=div_units,
            validators=[ValidateSpatial(1), ValidateParameter(f"bulk_{basename}")],
        )

        registry.add_field(
            (ftype, f"{basename}_divergence_absolute"),
            sampling_type="local",
            function=_divergence_abs,
            units=div_units,
        )

        def _tangential_over_magnitude(field, data):
            tr = (
                data[ftype, f"tangential_{basename}"]
                / data[ftype, f"{basename}_magnitude"]
            )
            return np.abs(tr)

        registry.add_field(
            (ftype, f"tangential_over_{basename}_magnitude"),
            sampling_type="local",
            function=_tangential_over_magnitude,
            take_log=False,
        )

        def _cylindrical_radius_component(field, data):
            """The cylindrical radius component of the vector field

            Relative to the coordinate system defined by the *normal* vector,
            *center*, and *bulk_* field parameters.
            """
            normal = data.get_field_parameter("normal")
            vectors = obtain_relative_velocity_vector(
                data, (xn, yn, zn), f"bulk_{basename}"
            )
            theta = data["index", "cylindrical_theta"]
            return get_cyl_r_component(vectors, theta, normal)

        registry.add_field(
            (ftype, f"{basename}_cylindrical_radius"),
            sampling_type="local",
            function=_cylindrical_radius_component,
            units=field_units,
            validators=[ValidateParameter("normal")],
        )

        registry.alias(
            (ftype, f"cylindrical_radial_{basename}"),
            (ftype, f"{basename}_cylindrical_radius"),
            deprecate=("4.0.0", "4.1.0"),
        )

        def _cylindrical_radial_absolute(field, data):
            """This field is deprecated and will be removed in a future version"""
            return np.abs(data[ftype, f"{basename}_cylindrical_radius"])

        registry.add_deprecated_field(
            (ftype, f"cylindrical_radial_{basename}_absolute"),
            function=_cylindrical_radial_absolute,
            sampling_type="local",
            since="4.0.0",
            removal="4.1.0",
            units=field_units,
            validators=[ValidateParameter("normal")],
        )

        def _cylindrical_theta_component(field, data):
            """The cylindrical theta component of the vector field

            Relative to the coordinate system defined by the *normal* vector,
            *center*, and *bulk_* field parameters.
            """
            normal = data.get_field_parameter("normal")
            vectors = obtain_relative_velocity_vector(
                data, (xn, yn, zn), f"bulk_{basename}"
            )
            theta = data["index", "cylindrical_theta"].copy()
            # Replicate theta along a new leading axis of length 3, one copy
            # per vector component.
            theta = np.tile(theta, (3,) + (1,) * len(theta.shape))
            return get_cyl_theta_component(vectors, theta, normal)

        registry.add_field(
            (ftype, f"{basename}_cylindrical_theta"),
            sampling_type="local",
            function=_cylindrical_theta_component,
            units=field_units,
            validators=[
                ValidateParameter("normal"),
                ValidateParameter("center"),
                ValidateParameter(f"bulk_{basename}"),
            ],
        )

        def _cylindrical_tangential_absolute(field, data):
            """This field is deprecated and will be removed in a future release"""
            return np.abs(data[ftype, f"cylindrical_tangential_{basename}"])

        registry.alias(
            (ftype, f"cylindrical_tangential_{basename}"),
            (ftype, f"{basename}_cylindrical_theta"),
            deprecate=("4.0.0", "4.1.0"),
        )

        registry.add_deprecated_field(
            (ftype, f"cylindrical_tangential_{basename}_absolute"),
            function=_cylindrical_tangential_absolute,
            sampling_type="local",
            since="4.0.0",
            removal="4.1.0",
            units=field_units,
        )

        def _cylindrical_z_component(field, data):
            """The cylindrical z component of the vector field

            Relative to the coordinate system defined by the *normal* vector,
            *center*, and *bulk_* field parameters.
            """
            normal = data.get_field_parameter("normal")
            vectors = obtain_relative_velocity_vector(
                data, (xn, yn, zn), f"bulk_{basename}"
            )
            return get_cyl_z_component(vectors, normal)

        registry.add_field(
            (ftype, f"{basename}_cylindrical_z"),
            sampling_type="local",
            function=_cylindrical_z_component,
            units=field_units,
            validators=[
                ValidateParameter("normal"),
                ValidateParameter("center"),
                ValidateParameter(f"bulk_{basename}"),
            ],
        )

    else:  # Create Cartesian fields for curvilinear coordinates

        def _cartesian_x(field, data):
            if registry.ds.geometry == "polar":
                # v_x = v_r cos(theta) - v_theta sin(theta): the azimuthal
                # component contributes too, exactly as in 3D cylindrical.
                return data[(ftype, f"{basename}_r")] * np.cos(
                    data[(ftype, "theta")]
                ) - data[(ftype, f"{basename}_theta")] * np.sin(data[(ftype, "theta")])

            elif registry.ds.geometry == "cylindrical":

                if data.ds.dimensionality == 2:
                    return data[(ftype, f"{basename}_r")]
                elif data.ds.dimensionality == 3:
                    return data[(ftype, f"{basename}_r")] * np.cos(
                        data[(ftype, "theta")]
                    ) - data[(ftype, f"{basename}_theta")] * np.sin(
                        data[(ftype, "theta")]
                    )

            elif registry.ds.geometry == "spherical":

                if data.ds.dimensionality == 2:
                    return data[(ftype, f"{basename}_r")] * np.sin(
                        data[(ftype, "theta")]
                    ) + data[(ftype, f"{basename}_theta")] * np.cos(
                        data[(ftype, "theta")]
                    )
                elif data.ds.dimensionality == 3:
                    return (
                        data[(ftype, f"{basename}_r")]
                        * np.sin(data[(ftype, "theta")])
                        * np.cos(data[(ftype, "phi")])
                        + data[(ftype, f"{basename}_theta")]
                        * np.cos(data[(ftype, "theta")])
                        * np.cos(data[(ftype, "phi")])
                        - data[(ftype, f"{basename}_phi")]
                        * np.sin(data[(ftype, "phi")])
                    )

        # it's redundant to define a cartesian x field for 1D data
        if registry.ds.dimensionality > 1:
            registry.add_field(
                (ftype, f"{basename}_cartesian_x"),
                sampling_type="local",
                function=_cartesian_x,
                units=field_units,
                display_field=True,
            )

        def _cartesian_y(field, data):
            if registry.ds.geometry == "polar":
                # v_y = v_r sin(theta) + v_theta cos(theta)
                return data[(ftype, f"{basename}_r")] * np.sin(
                    data[(ftype, "theta")]
                ) + data[(ftype, f"{basename}_theta")] * np.cos(data[(ftype, "theta")])

            elif registry.ds.geometry == "cylindrical":

                if data.ds.dimensionality == 2:
                    return data[(ftype, f"{basename}_z")]
                elif data.ds.dimensionality == 3:
                    return data[(ftype, f"{basename}_r")] * np.sin(
                        data[(ftype, "theta")]
                    ) + data[(ftype, f"{basename}_theta")] * np.cos(
                        data[(ftype, "theta")]
                    )

            elif registry.ds.geometry == "spherical":

                if data.ds.dimensionality == 2:
                    return data[(ftype, f"{basename}_r")] * np.cos(
                        data[(ftype, "theta")]
                    ) - data[(ftype, f"{basename}_theta")] * np.sin(
                        data[(ftype, "theta")]
                    )
                elif data.ds.dimensionality == 3:
                    return (
                        data[(ftype, f"{basename}_r")]
                        * np.sin(data[(ftype, "theta")])
                        * np.sin(data[(ftype, "phi")])
                        + data[(ftype, f"{basename}_theta")]
                        * np.cos(data[(ftype, "theta")])
                        * np.sin(data[(ftype, "phi")])
                        + data[(ftype, f"{basename}_phi")]
                        * np.cos(data[(ftype, "phi")])
                    )

        if registry.ds.dimensionality >= 2:
            registry.add_field(
                (ftype, f"{basename}_cartesian_y"),
                sampling_type="local",
                function=_cartesian_y,
                units=field_units,
                display_field=True,
            )

        def _cartesian_z(field, data):
            # Polar and cylindrical coordinates share the Cartesian z axis.
            # Previously polar fell through and returned None.
            if registry.ds.geometry in ("cylindrical", "polar"):
                return data[(ftype, f"{basename}_z")]
            elif registry.ds.geometry == "spherical":
                return data[(ftype, f"{basename}_r")] * np.cos(
                    data[(ftype, "theta")]
                ) - data[(ftype, f"{basename}_theta")] * np.sin(data[(ftype, "theta")])

        if registry.ds.dimensionality == 3:
            registry.add_field(
                (ftype, f"{basename}_cartesian_z"),
                sampling_type="local",
                function=_cartesian_z,
                units=field_units,
                display_field=True,
            )
601
602
def create_averaged_field(
    registry,
    basename,
    field_units,
    ftype="gas",
    slice_info=None,
    validators=None,
    weight="mass",
):
    """Register ``(ftype, f"averaged_{basename}")``, a 3x3x3 weighted average.

    Each interior cell holds the *weight*-weighted mean of ``(ftype,
    basename)`` over its 27-cell neighborhood. The one-cell border of the
    output is zero. A ``ValidateSpatial(1, ...)`` validator is appended to
    *validators* so the field is evaluated with one ghost zone.

    ``slice_info`` is accepted for signature compatibility and is unused.
    The caller's *validators* sequence is never modified.
    """
    # Build a new list: ``validators += [...]`` would mutate a list owned by
    # the caller, and would raise TypeError for a tuple.
    validators = list(validators) if validators is not None else []
    validators.append(ValidateSpatial(1, [(ftype, basename)]))

    def _averaged_field(field, data):
        def atleast_4d(array):
            # Treat a single 3D block as a stack of one block.
            if array.ndim == 3:
                return array[..., None]
            else:
                return array

        # Fetch and reshape each input once, not on every loop iteration.
        values = data[(ftype, basename)]
        weights = data[(ftype, weight)]
        values4 = atleast_4d(values)
        weights4 = atleast_4d(weights)

        nx, ny, nz, ngrids = values4.shape
        new_field = data.ds.arr(
            np.zeros((nx - 2, ny - 2, nz - 2, ngrids), dtype=np.float64),
            (just_one(values) * just_one(weights)).units,
        )
        weight_field = data.ds.arr(
            np.zeros((nx - 2, ny - 2, nz - 2, ngrids), dtype=np.float64),
            weights.units,
        )
        i_i, j_i, k_i = np.mgrid[0:3, 0:3, 0:3]

        # Sum 27 shifted (nx-2, ny-2, nz-2) views, one per neighborhood offset.
        for i, j, k in zip(i_i.ravel(), j_i.ravel(), k_i.ravel()):
            sl = (
                slice(i, nx - (2 - i)),
                slice(j, ny - (2 - j)),
                slice(k, nz - (2 - k)),
            )
            new_field += values4[sl] * weights4[sl]
            weight_field += weights4[sl]

        # Embed the interior averages in a full-size array; the border cells,
        # which lack a complete neighborhood, stay zero.
        new_field2 = data.ds.arr(np.zeros((nx, ny, nz, ngrids)), values.units)
        new_field2[1:-1, 1:-1, 1:-1] = new_field / weight_field

        if values.ndim == 3:
            return new_field2[..., 0]
        else:
            return new_field2

    registry.add_field(
        (ftype, f"averaged_{basename}"),
        sampling_type="cell",
        function=_averaged_field,
        units=field_units,
        validators=validators,
    )
665