from yt.fields.local_fields import add_field
from yt.utilities.linear_interpolators import (
    BilinearFieldInterpolator,
    TrilinearFieldInterpolator,
    UnilinearFieldInterpolator,
)

# Maps the dimensionality of the lookup table to the interpolator class
# used for it.
_int_class = {
    1: UnilinearFieldInterpolator,
    2: BilinearFieldInterpolator,
    3: TrilinearFieldInterpolator,
}


def add_interpolated_field(
    name,
    units,
    table_data,
    axes_data,
    axes_fields,
    ftype="gas",
    particle_type=False,
    validators=None,
    truncate=True,
):
    """Register a derived field whose values are interpolated from a table.

    An interpolator is built once from ``table_data`` and registered with
    ``add_field`` under the key ``(ftype, name)``. When the field is
    requested, the interpolator is called on the data object, using the
    fields listed in ``axes_fields`` as interpolation coordinates.

    Parameters
    ----------
    name
        Name of the new field; it is registered as ``(ftype, name)``.
    units
        Units of the new field, forwarded to ``add_field``.
    table_data
        Lookup table (anything with a ``shape``) with 1, 2, or 3 dimensions.
    axes_data
        Sequence with one entry per table dimension describing that axis
        to the interpolator. NOTE(review): presumably each entry is a bins
        array or a (lower, upper) bounds pair as accepted by
        ``yt.utilities.linear_interpolators``; confirm.
    axes_fields
        Sequence with one field name per table dimension, giving the
        coordinate field for that axis.
    ftype
        Field type of the new field. Defaults to ``"gas"``.
    particle_type
        Forwarded to ``add_field``.
    validators
        Forwarded to ``add_field``.
    truncate
        Forwarded to the interpolator class.

    Raises
    ------
    RuntimeError
        If ``table_data`` is not 1d, 2d, or 3d, or if ``axes_data`` and
        ``axes_fields`` do not both have one entry per table dimension.
    """
    ndim = len(table_data.shape)

    if ndim not in _int_class:
        raise RuntimeError(
            "Interpolated field can only be created from 1d, 2d, or 3d data."
        )

    if len(axes_fields) != len(axes_data) or len(axes_fields) != ndim:
        raise RuntimeError(
            "Data dimension mismatch: data is %d, "
            "%d axes data provided, and %d axes fields provided."
            % (ndim, len(axes_data), len(axes_fields))
        )

    if ndim == 1:
        # The 2d/3d interpolators unpack one entry per axis from these
        # sequences, but the 1d interpolator takes a single boundary spec
        # and a single field name. Passing the length-1 sequences through
        # unchanged made 1d tables unusable, so unwrap the lone entry.
        # NOTE(review): verify against UnilinearFieldInterpolator.__init__.
        axes_data = axes_data[0]
        axes_fields = axes_fields[0]

    int_class = _int_class[ndim]
    # Built once at registration time and shared by every evaluation of
    # the field through the closure below.
    my_interpolator = int_class(table_data, axes_data, axes_fields, truncate=truncate)

    def _interpolated_field(field, data):
        # ``field`` is unused; only the data object is interpolated over.
        return my_interpolator(data)

    add_field(
        (ftype, name),
        function=_interpolated_field,
        units=units,
        validators=validators,
        particle_type=particle_type,
    )