1import numpy as np
2
3from yt.data_objects.api import ImageArray
4from yt.funcs import is_sequence, mylog
5from yt.units.unit_object import Unit
6from yt.utilities.lib.partitioned_grid import PartitionedGrid
7from yt.utilities.lib.pixelization_routines import (
8    normalization_2d_utility,
9    off_axis_projection_SPH,
10)
11
12from .render_source import KDTreeVolumeSource
13from .scene import Scene
14from .transfer_functions import ProjectionTransferFunction
15from .utils import data_source_or_all
16
17
def off_axis_projection(
    data_source,
    center,
    normal_vector,
    width,
    resolution,
    item,
    weight=None,
    volume=None,
    no_ghost=False,
    interpolated=False,
    north_vector=None,
    num_threads=1,
    method="integrate",
):
    r"""Project through a dataset, off-axis, and return the image plane.

    This function will accept the necessary items to integrate through a volume
    at an arbitrary angle and return the integrated field of view to the user.
    Note that if a weight is supplied, it will multiply the pre-interpolated
    values together, then create cell-centered values, then interpolate within
    the cell to conduct the integration.

    Parameters
    ----------
    data_source : ~yt.data_objects.static_output.Dataset
                  or ~yt.data_objects.data_containers.YTSelectionDataContainer
        This is the dataset or data object to volume render.
    center : array_like
        The current 'center' of the view port -- the focal point for the
        camera.  Interpreted in code_length if it carries no units.
    normal_vector : array_like
        The vector between the camera position and the center.
    width : float or list of floats
        The current width of the image.  If a single float, the volume is
        cubical, but if not, it is left/right, top/bottom, front/back.
        Interpreted in code_length if it carries no units.
    resolution : int or list of ints
        The number of pixels in each direction.  A single int is used for
        both image axes.
    item: string
        The field to project through the volume
    weight : optional, default None
        If supplied, the field will be pre-multiplied by this, then divided by
        the integrated value of this field.  This returns an average rather
        than a sum.
    volume : `yt.extensions.volume_rendering.AMRKDTree`, optional
        The volume to ray cast through.  Currently ignored by this function.
    no_ghost: bool, optional
        Optimization option for homogenized bricks (extrapolate instead of
        interpolating from ghost zones).  Currently ignored by this function.
        Default: False
    interpolated : optional, default False
        If True, the data is first interpolated to vertex-centered data,
        then tri-linearly interpolated along the ray. Not suggested for
        quantitative studies.  Only False is currently implemented;
        passing True raises NotImplementedError.
    north_vector : optional, array_like, default None
        A vector that, if specified, restricts the orientation such that the
        north vector dotted into the image plane points "up". Useful for
        rotations or camera movements.
    num_threads: integer, optional, default 1
        Use this many OpenMP threads during projection.
    method : string
        The method of projection.  Valid methods are:

        "integrate" with no weight_field specified : integrate the requested
        field along the line of sight.

        "integrate" with a weight_field specified : weight the requested
        field by the weighting field and integrate along the line of sight.

        "sum" : This method is the same as integrate, except that it does not
        multiply by a path length when performing the integration, and is
        just a straight summation of the field along the given axis. WARNING:
        This should only be used for uniform resolution grid datasets, as other
        datasets may result in unphysical images.  Not supported for SPH
        datasets.

    Returns
    -------
    image : ImageArray
        A 2D array of the final projected values with shape given by
        ``resolution``, carrying the units of the projection.

    Raises
    ------
    NotImplementedError
        If ``method`` is not 'integrate' or 'sum', if ``interpolated`` is
        True, or if a non-'integrate' method is requested for SPH data.
    RuntimeError
        If an SPH dataset is projected with a field that is not an SPH field.

    Examples
    --------

    >>> image = off_axis_projection(
    ...     ds,
    ...     [0.5, 0.5, 0.5],
    ...     [0.2, 0.3, 0.4],
    ...     0.2,
    ...     N,
    ...     ("gas", "temperature"),
    ...     ("gas", "density"),
    ... )
    >>> write_image(np.log10(image), "offaxis.png")

    """
    if method not in ("integrate", "sum"):
        raise NotImplementedError(
            "Only 'integrate' or 'sum' methods are valid for off-axis-projections"
        )

    if interpolated:
        raise NotImplementedError(
            "Only interpolated=False methods are currently implemented "
            "for off-axis-projections"
        )

    data_source = data_source_or_all(data_source)

    item = data_source._determine_fields([item])[0]

    # Assure vectors are numpy arrays as expected by cython code
    normal_vector = np.array(normal_vector, dtype="float64")
    if north_vector is not None:
        north_vector = np.array(north_vector, dtype="float64")
    # Add the normal as a field parameter to the data source
    # so line of sight fields can use it
    data_source.set_field_parameter("axis", normal_vector)

    # Sanitize units
    if not hasattr(center, "units"):
        center = data_source.ds.arr(center, "code_length")
    if not hasattr(width, "units"):
        width = data_source.ds.arr(width, "code_length")

    # Expand scalar resolution/width before branching: both the SPH and the
    # grid code paths index them per axis.
    if not is_sequence(resolution):
        resolution = [resolution] * 2
    if not is_sequence(width):
        width = data_source.ds.arr([width] * 3)

    if hasattr(data_source.ds, "_sph_ptypes"):
        return _off_axis_projection_sph(
            data_source,
            center,
            normal_vector,
            width,
            resolution,
            item,
            weight,
            north_vector,
            method,
        )
    return _off_axis_projection_grid(
        data_source,
        center,
        normal_vector,
        width,
        resolution,
        item,
        weight,
        north_vector,
        num_threads,
        method,
    )


def _orientation_vectors(normal_vector, north_vector):
    """Return ``(normal, north, east)`` vectors describing the image plane.

    ``normal`` is ``normal_vector`` normalized to unit length.  If
    ``north_vector`` is None a default is chosen so that, when
    ``normal_vector`` is one of the cartesian coordinate axes, the projection
    matches the corresponding on-axis projection; otherwise ``north`` is the
    supplied vector normalized to unit length.
    """
    normal = np.array(normal_vector)
    normal = normal / np.linalg.norm(normal)
    if north_vector is None:
        vecs = np.identity(3)
        t = np.cross(vecs, normal).sum(axis=1)
        ax = t.argmax()
        east_vector = np.cross(vecs[ax, :], normal).ravel()
        north = np.cross(normal, east_vector).ravel()
    else:
        north = np.array(north_vector)
        north = north / np.linalg.norm(north)
        east_vector = np.cross(north, normal).ravel()
    return normal, north, east_vector


def _sph_accumulate(
    data_source,
    ptype,
    ppos,
    bounds,
    center,
    width,
    normal_vector,
    north,
    field,
    field_units,
    buff,
    weight=None,
    weight_units=None,
):
    """Deposit the SPH-smoothed projection of ``field`` into ``buff``.

    ``buff`` is updated in place by ``off_axis_projection_SPH``, one "io"
    chunk at a time.  When ``weight`` is given, it is forwarded (converted to
    ``weight_units``) as the ``weight_field`` keyword.
    """
    # These conversions do not depend on the chunk, so do them once.
    center_code = center.to("code_length").d
    width_code = width.to("code_length").d
    for chunk in data_source.chunks([], "io"):
        extra = {}
        if weight is not None:
            extra["weight_field"] = chunk[weight].in_units(weight_units)
        off_axis_projection_SPH(
            chunk[ptype, ppos[0]].to("code_length").d,
            chunk[ptype, ppos[1]].to("code_length").d,
            chunk[ptype, ppos[2]].to("code_length").d,
            chunk[ptype, "mass"].to("code_mass").d,
            chunk[ptype, "density"].to("code_density").d,
            chunk[ptype, "smoothing_length"].to("code_length").d,
            bounds,
            center_code,
            width_code,
            chunk[field].in_units(field_units),
            buff,
            normal_vector,
            north,
            **extra,
        )


def _off_axis_projection_sph(
    data_source,
    center,
    normal_vector,
    width,
    resolution,
    item,
    weight,
    north_vector,
    method,
):
    """SPH branch of ``off_axis_projection``: smooth particles onto the image."""
    if method != "integrate":
        raise NotImplementedError("SPH Only allows 'integrate' method")

    ds = data_source.ds
    sph_ptypes = ds._sph_ptypes
    fi = ds.field_info[item]

    raise_error = False

    ptype = sph_ptypes[0]
    ppos = [f"particle_position_{ax}" for ax in "xyz"]
    # Assure that the field we're trying to off-axis project
    # has a field type as the SPH particle type or if the field is an
    # alias to an SPH field or is a 'gas' field
    if item[0] in ds.known_filters:
        if item[0] not in sph_ptypes:
            raise_error = True
        else:
            ptype = item[0]
            ppos = ["x", "y", "z"]
    elif fi.alias_field:
        if fi.alias_name[0] not in sph_ptypes:
            raise_error = True
        elif item[0] != "gas":
            ptype = item[0]
    else:
        if fi.name[0] not in sph_ptypes and fi.name[0] != "gas":
            raise_error = True

    if raise_error:
        raise RuntimeError(
            "Can only perform off-axis projections for SPH fields, "
            "Received '%s'" % (item,)
        )

    _, north, east_vector = _orientation_vectors(normal_vector, north_vector)

    buf = np.zeros((resolution[0], resolution[1]), dtype="float64")

    # Axis-aligned bounding box [x_min, x_max, y_min, y_max, z_min, z_max]
    # around the center.
    bounds = []
    for axis in range(3):
        half = width[axis] / 2
        bounds.extend([center[axis] - half, center[axis] + half])
    ounits = fi.output_units

    if weight is None:
        _sph_accumulate(
            data_source,
            ptype,
            ppos,
            bounds,
            center,
            width,
            normal_vector,
            north,
            item,
            ounits,
            buf,
        )

        # Assure that the path length unit is in the default length units
        # for the dataset by scaling the units of the smoothing length
        path_length_unit = Unit(
            ds._get_field_info((ptype, "smoothing_length")).units,
            registry=ds.unit_registry,
        )
        default_path_length_unit = ds.unit_system["length"]
        buf *= ds.quan(1, path_length_unit).in_units(default_path_length_unit)
        item_unit = Unit(ds._get_field_info(item).units, registry=ds.unit_registry)
        funits = item_unit * default_path_length_unit
    else:
        # With a weight field, take two projections -- one of field*weight,
        # the other of just weight -- and divide them.
        weight_buff = np.zeros((resolution[0], resolution[1]), dtype="float64")
        wounits = ds.field_info[weight].output_units

        _sph_accumulate(
            data_source,
            ptype,
            ppos,
            bounds,
            center,
            width,
            normal_vector,
            north,
            item,
            ounits,
            buf,
            weight=weight,
            weight_units=wounits,
        )
        _sph_accumulate(
            data_source,
            ptype,
            ppos,
            bounds,
            center,
            width,
            normal_vector,
            north,
            weight,
            wounits,
            weight_buff,
        )

        normalization_2d_utility(buf, weight_buff)
        funits = Unit(ds._get_field_info(item).units, registry=ds.unit_registry)

    myinfo = {
        "field": item,
        "east_vector": east_vector,
        # The caller-supplied value (None when the default orientation is used).
        "north_vector": north_vector,
        "normal_vector": normal_vector,
        "width": width,
        "units": funits,
        "type": "SPH smoothed projection",
    }

    return ImageArray(buf, funits, registry=ds.unit_registry, info=myinfo)


def _cast_projection_rays(
    sc,
    vol,
    data_source,
    center,
    normal_vector,
    width,
    resolution,
    north_vector,
    num_threads,
    funits,
):
    """Set up the camera, ray-cast every block and return the raw image.

    Channel 0 of the last axis of the returned ImageArray holds the projected
    values; when a weight field is set on ``vol``, channel 1 holds the
    accumulated weight.
    """
    ptf = ProjectionTransferFunction()
    vol.set_transfer_function(ptf)
    camera = sc.add_camera(data_source)
    camera.set_width(width)
    camera.resolution = resolution

    normal, north, _ = _orientation_vectors(normal_vector, north_vector)
    camera.position = center - width[2] * normal
    camera.focus = center
    camera.switch_orientation(normal, north)

    sc.add_source(vol)

    vol.set_sampler(camera, interpolated=False)
    assert vol.sampler is not None

    fields = [vol.field]
    if vol.weight_field is not None:
        fields.append(vol.weight_field)

    mylog.debug("Casting rays")

    for grid, mask in data_source.blocks:
        data = []
        for f in fields:
            # strip units before multiplying by mask for speed
            grid_data = grid[f]
            data.append(
                data_source.ds.arr(
                    grid_data.d * mask, grid_data.units, dtype="float64"
                )
            )
        pg = PartitionedGrid(
            grid.id,
            data,
            mask.astype("uint8"),
            grid.LeftEdge,
            grid.RightEdge,
            grid.ActiveDimensions.astype("int64"),
        )
        grid.clear_data()
        vol.sampler(pg, num_threads=num_threads)

    image = vol.finalize_image(camera, vol.sampler.aimage)
    return ImageArray(
        image, funits, registry=data_source.ds.unit_registry, info=image.info
    )


def _off_axis_projection_grid(
    data_source,
    center,
    normal_vector,
    width,
    resolution,
    item,
    weight,
    north_vector,
    num_threads,
    method,
):
    """Grid branch of ``off_axis_projection``: ray-cast a KD-tree volume."""
    sc = Scene()
    # NOTE(review): accessed for its side effect, presumably to make sure the
    # index is built before the field-info lookups below.
    data_source.ds.index

    funits = data_source.ds._get_field_info(item).units

    vol = KDTreeVolumeSource(data_source, item)
    vol.num_threads = num_threads

    ray_args = (
        sc,
        vol,
        data_source,
        center,
        normal_vector,
        width,
        resolution,
        north_vector,
        num_threads,
        funits,
    )

    if weight is None:
        vol.set_field(item)
        image = _cast_projection_rays(*ray_args)
    else:
        # This is a temporary field holding item * weight; it is removed in
        # the ``finally`` below so it never outlives this call, even if ray
        # casting raises.
        weightfield = ("index", "temp_weightfield")

        def _make_wf(f, w):
            def temp_weightfield(a, b):
                tr = b[f].astype("float64") * b[w]
                return tr.d

            return temp_weightfield

        data_source.ds.field_info.add_field(
            weightfield, sampling_type="cell", function=_make_wf(item, weight)
        )
        try:
            # Now we have to tell the dataset to add it and to calculate
            # its dependencies..
            deps, _ = data_source.ds.field_info.check_derived_fields([weightfield])
            data_source.ds.field_dependencies.update(deps)
            vol.set_field(weightfield)
            vol.set_weight_field(weight)
            image = _cast_projection_rays(*ray_args)
        finally:
            data_source.ds.field_info.pop(weightfield, None)

    if method == "integrate":
        if weight is None:
            dl = width[2].in_units(data_source.ds.unit_system["length"])
            image *= dl
        else:
            # Divide the weighted sum by the total weight; pixels with no
            # weight are set to zero.
            mask = image[:, :, 1] == 0
            image[:, :, 0] /= image[:, :, 1]
            image[mask] = 0

    return image[:, :, 0]
415