1from yt.fields.local_fields import add_field
2from yt.utilities.linear_interpolators import (
3    BilinearFieldInterpolator,
4    TrilinearFieldInterpolator,
5    UnilinearFieldInterpolator,
6)
7
# Lookup from the dimensionality of an interpolation table (its number of
# axes) to the yt interpolator class used for it: 1 -> unilinear,
# 2 -> bilinear, 3 -> trilinear.
_int_class = dict(
    enumerate(
        (
            UnilinearFieldInterpolator,
            BilinearFieldInterpolator,
            TrilinearFieldInterpolator,
        ),
        start=1,
    )
)
13
14
def add_interpolated_field(
    name,
    units,
    table_data,
    axes_data,
    axes_fields,
    ftype="gas",
    particle_type=False,
    validators=None,
    truncate=True,
):
    """Register a field whose values are linearly interpolated from a table.

    The number of dimensions of ``table_data`` (``len(table_data.shape)``)
    selects a 1-D, 2-D or 3-D linear interpolator, which is then wrapped in a
    field function and registered with ``add_field`` under ``(ftype, name)``.

    Parameters
    ----------
    name : str
        Name of the new field; combined with ``ftype`` to form the field key.
    units : str
        Units forwarded to ``add_field`` for the new field.
    table_data : array-like
        Tabulated values to interpolate.  Must expose ``.shape`` and have
        1, 2 or 3 dimensions.
    axes_data : sequence
        One entry per table dimension, forwarded to the interpolator
        (presumably the coordinate values along each table axis -- confirm
        against ``yt.utilities.linear_interpolators``).
    axes_fields : sequence
        One entry per table dimension, forwarded to the interpolator
        (presumably the field names whose values index into the table).
    ftype : str, optional
        Field type half of the registered field key.  Defaults to ``"gas"``.
    particle_type : bool, optional
        If true the field is registered with ``sampling_type="particle"``,
        otherwise with ``sampling_type="cell"``.
    validators : list, optional
        Forwarded unchanged to ``add_field``.
    truncate : bool, optional
        Forwarded unchanged to the interpolator constructor.

    Raises
    ------
    RuntimeError
        If ``table_data`` is not 1-, 2- or 3-dimensional, or if the number of
        entries in ``axes_data`` and ``axes_fields`` does not match the
        table's dimensionality.
    """
    ndim = len(table_data.shape)
    if ndim not in _int_class:
        raise RuntimeError(
            "Interpolated field can only be created from 1d, 2d, or 3d data."
        )

    if len(axes_fields) != len(axes_data) or len(axes_fields) != ndim:
        raise RuntimeError(
            "Data dimension mismatch: data is %d, "
            "%d axes data provided, and %d axes fields provided."
            % (ndim, len(axes_data), len(axes_fields))
        )

    int_class = _int_class[ndim]
    my_interpolator = int_class(table_data, axes_data, axes_fields, truncate=truncate)

    def _interpolated_field(field, data):
        # ``field`` is unused; the interpolator only needs the data object.
        return my_interpolator(data)

    # ``particle_type`` is a deprecated ``add_field`` keyword; yt now expects
    # an explicit ``sampling_type``.  Translate the boolean here so the public
    # signature of this helper stays backward compatible.
    sampling_type = "particle" if particle_type else "cell"

    add_field(
        (ftype, name),
        function=_interpolated_field,
        sampling_type=sampling_type,
        units=units,
        validators=validators,
    )
52