1import builtins
2from copy import deepcopy
3
4import numpy as np
5
6from yt.config import ytcfg
7from yt.data_objects.api import ImageArray
8from yt.funcs import ensure_numpy_array, get_num_threads, get_pbar, is_sequence, mylog
9from yt.geometry.geometry_handler import cached_property
10from yt.units.yt_array import YTArray
11from yt.utilities.amr_kdtree.api import AMRKDTree
12from yt.utilities.exceptions import YTNotInsideNotebook
13from yt.utilities.lib.grid_traversal import (
14    arr_fisheye_vectors,
15    arr_pix2vec_nest,
16    pixelize_healpix,
17)
18from yt.utilities.lib.image_samplers import (
19    InterpolatedProjectionSampler,
20    LightSourceRenderSampler,
21    ProjectionSampler,
22    VolumeRenderSampler,
23)
24from yt.utilities.lib.misc_utilities import lines
25from yt.utilities.lib.partitioned_grid import PartitionedGrid
26from yt.utilities.math_utils import get_rotation_matrix
27from yt.utilities.object_registries import data_object_registry
28from yt.utilities.orientation import Orientation
29from yt.utilities.parallel_tools.parallel_analysis_interface import (
30    ParallelAnalysisInterface,
31    parallel_objects,
32)
33from yt.visualization.image_writer import apply_colormap, write_bitmap, write_image
34from yt.visualization.volume_rendering.blenders import enhance_rgba
35
36from .transfer_functions import ProjectionTransferFunction
37
38
def get_corners(le, re):
    """Return the 8 corner vertices of the axis-aligned box spanned by
    left edge ``le`` and right edge ``re``.

    Corners are ordered counter-clockwise around the bottom face
    (z = le[2]) first, then the top face (z = re[2]), as a (8, 3)
    float64 array.
    """
    # Walk the four (x, y) combinations in CCW order for each z level.
    xy_cycle = ((le, le), (re, le), (re, re), (le, re))
    return np.array(
        [[x[0], y[1], z[2]] for z in (le, re) for x, y in xy_cycle],
        dtype="float64",
    )
53
54
55class Camera(ParallelAnalysisInterface):
56    r"""A viewpoint into a volume, for volume rendering.
57
58    The camera represents the eye of an observer, which will be used to
59    generate ray-cast volume renderings of the domain.
60
61    Parameters
62    ----------
63    center : array_like
64        The current "center" of the view port -- the focal point for the
65        camera.
66    normal_vector : array_like
67        The vector between the camera position and the center.
68    width : float or list of floats
69        The current width of the image.  If a single float, the volume is
70        cubical, but if not, it is left/right, top/bottom, front/back.
71    resolution : int or list of ints
72        The number of pixels in each direction.
73    transfer_function : `yt.visualization.volume_rendering.TransferFunction`
74        The transfer function used to map values to colors in an image.  If
75        not specified, defaults to a ProjectionTransferFunction.
76    north_vector : array_like, optional
        The 'up' direction for the plane of rays.  If not specified, it is
        calculated automatically.
79    steady_north : bool, optional
80        Boolean to control whether to normalize the north_vector
81        by subtracting off the dot product of it and the normal
82        vector.  Makes it easier to do rotations along a single
83        axis.  If north_vector is specified, is switched to
84        True. Default: False
    volume : `yt.utilities.amr_kdtree.api.AMRKDTree`, optional
86        The volume to ray cast through.  Can be specified for finer-grained
87        control, but otherwise will be automatically generated.
88    fields : list of fields, optional
89        This is the list of fields we want to volume render; defaults to
90        Density.
91    log_fields : list of bool, optional
92        Whether we should take the log of the fields before supplying them to
93        the volume rendering mechanism.
94    sub_samples : int, optional
95        The number of samples to take inside every cell per ray.
96    ds : ~yt.data_objects.static_output.Dataset
        For now, this is a required parameter!  But in the future it will become
98        optional.  This is the dataset to volume render.
99    max_level: int, optional
100        Specifies the maximum level to be rendered.  Also
101        specifies the maximum level used in the kd-Tree
102        construction.  Defaults to None (all levels), and only
103        applies if use_kd=True.
104    no_ghost: bool, optional
105        Optimization option.  If True, homogenized bricks will
106        extrapolate out from grid instead of interpolating from
107        ghost zones that have to first be calculated.  This can
108        lead to large speed improvements, but at a loss of
109        accuracy/smoothness in resulting image.  The effects are
110        less notable when the transfer function is smooth and
111        broad. Default: True
112    data_source: data container, optional
113        Optionally specify an arbitrary data source to the volume rendering.
114        All cells not included in the data source will be ignored during ray
115        casting. By default this will get set to ds.all_data().
116
117    Examples
118    --------
119
120    >>> from yt.mods import *
121    >>> import yt.visualization.volume_rendering.api as vr
122
123    >>> ds = load("DD1701")  # Load a dataset
124    >>> c = [0.5] * 3  # Center
125    >>> L = [1.0, 1.0, 1.0]  # Viewpoint
126    >>> W = np.sqrt(3)  # Width
127    >>> N = 1024  # Pixels (1024^2)
128
129    # Get density min, max
130    >>> mi, ma = ds.all_data().quantities["Extrema"]("Density")[0]
131    >>> mi, ma = np.log10(mi), np.log10(ma)
132
133    # Construct transfer function
134    >>> tf = vr.ColorTransferFunction((mi - 2, ma + 2))
135    # Sample transfer function with 5 gaussians.  Use new col_bounds keyword.
136    >>> tf.add_layers(5, w=0.05, col_bounds=(mi + 1, ma), colormap="spectral")
137
138    # Create the camera object
139    >>> cam = vr.Camera(c, L, W, (N, N), transfer_function=tf, ds=ds)
140
141    # Ray cast, and save the image.
142    >>> image = cam.snapshot(fn="my_rendering.png")
143
144    """
145    _sampler_object = VolumeRenderSampler
146    _tf_figure = None
147    _render_figure = None
148
    def __init__(
        self,
        center,
        normal_vector,
        width,
        resolution,
        transfer_function=None,
        north_vector=None,
        steady_north=False,
        volume=None,
        fields=None,
        log_fields=None,
        sub_samples=5,
        ds=None,
        min_level=None,
        max_level=None,
        no_ghost=True,
        data_source=None,
        use_light=False,
    ):
        """Set up the camera orientation, view box, fields, transfer
        function, and kd-tree volume.

        See the class docstring for parameter descriptions.
        """
        ParallelAnalysisInterface.__init__(self)
        if ds is not None:
            self.ds = ds
        if not is_sequence(resolution):
            # A single integer means a square image.
            resolution = (resolution, resolution)
        self.resolution = resolution
        self.sub_samples = sub_samples
        self.rotation_vector = north_vector
        # A (value, unit-string) pair like (10, "kpc") becomes a quantity,
        # then is normalized to code_length for all later manipulation.
        if is_sequence(width) and len(width) > 1 and isinstance(width[1], str):
            width = self.ds.quan(width[0], units=width[1])
            # Now convert back to code length for subsequent manipulation
            width = width.in_units("code_length").value
        if not is_sequence(width):
            width = (width, width, width)  # left/right, top/bottom, front/back
        if not isinstance(width, YTArray):
            width = self.ds.arr(width, units="code_length")
        if not isinstance(center, YTArray):
            center = self.ds.arr(center, units="code_length")
        # Ensure that width and center are in the same units
        # Cf. https://bitbucket.org/yt_analysis/yt/issue/1080
        width.convert_to_units("code_length")
        center.convert_to_units("code_length")
        self.orienter = Orientation(
            normal_vector, north_vector=north_vector, steady_north=steady_north
        )
        if not steady_north:
            # Without steady_north, subsequent rotations use the orienter's
            # computed north (up) vector rather than the user-supplied one.
            self.rotation_vector = self.orienter.unit_vectors[1]
        self._setup_box_properties(width, center, self.orienter.unit_vectors)
        if fields is None:
            fields = [("gas", "density")]
        self.fields = fields
        if transfer_function is None:
            transfer_function = ProjectionTransferFunction()
        self.transfer_function = transfer_function
        self.log_fields = log_fields
        dd = self.ds.all_data()
        efields = dd._determine_fields(self.fields)
        if self.log_fields is None:
            # Default log-scaling follows each field's take_log setting.
            self.log_fields = [self.ds._get_field_info(*f).take_log for f in efields]
        self.no_ghost = no_ghost
        self.use_light = use_light
        self.light_dir = None
        self.light_rgba = None
        if self.no_ghost:
            mylog.warning(
                "no_ghost is currently True (default). "
                "This may lead to artifacts at grid boundaries."
            )

        if data_source is None:
            data_source = self.ds.all_data()
        self.data_source = data_source

        if volume is None:
            # Build a kd-tree decomposition of the data for ray traversal.
            volume = AMRKDTree(
                self.ds,
                min_level=min_level,
                max_level=max_level,
                data_source=self.data_source,
            )
        self.volume = volume
230
231    def _setup_box_properties(self, width, center, unit_vectors):
232        self.width = width
233        self.center = center
234        self.box_vectors = YTArray(
235            [
236                unit_vectors[0] * width[0],
237                unit_vectors[1] * width[1],
238                unit_vectors[2] * width[2],
239            ]
240        )
241        self.origin = center - 0.5 * width.dot(YTArray(unit_vectors, ""))
242        self.back_center = center - 0.5 * width[2] * unit_vectors[2]
243        self.front_center = center + 0.5 * width[2] * unit_vectors[2]
244
    def update_view_from_matrix(self, mat):
        """Update the camera view from a transformation matrix.

        Not implemented for this camera type; ``mat`` is ignored.  Provided
        as a hook that subclasses may override.
        """
        pass
247
248    def project_to_plane(self, pos, res=None):
249        if res is None:
250            res = self.resolution
251        dx = np.dot(pos - self.origin, self.orienter.unit_vectors[1])
252        dy = np.dot(pos - self.origin, self.orienter.unit_vectors[0])
253        dz = np.dot(pos - self.center, self.orienter.unit_vectors[2])
254        # Transpose into image coords.
255        py = (res[0] * (dx / self.width[0])).astype("int")
256        px = (res[1] * (dy / self.width[1])).astype("int")
257        return px, py, dz
258
259    def draw_grids(self, im, alpha=0.3, cmap=None, min_level=None, max_level=None):
260        r"""Draws Grids on an existing volume rendering.
261
262        By mapping grid level to a color, draws edges of grids on
263        a volume rendering using the camera orientation.
264
265        Parameters
266        ----------
267        im: Numpy ndarray
268            Existing image that has the same resolution as the Camera,
269            which will be painted by grid lines.
270        alpha : float, optional
271            The alpha value for the grids being drawn.  Used to control
272            how bright the grid lines are with respect to the image.
273            Default : 0.3
274        cmap : string, optional
275            Colormap to be used mapping grid levels to colors.
276        min_level : int, optional
277            Optional parameter to specify the min level grid boxes
278            to overplot on the image.
279        max_level : int, optional
280            Optional parameters to specify the max level grid boxes
281            to overplot on the image.
282
283        Returns
284        -------
285        None
286
287        Examples
288        --------
289        >>> im = cam.snapshot()
290        >>> cam.add_grids(im)
291        >>> write_bitmap(im, "render_with_grids.png")
292
293        """
294        if cmap is None:
295            cmap = ytcfg.get("yt", "default_colormap")
296        region = self.data_source
297        corners = []
298        levels = []
299        for block, _mask in region.blocks:
300            block_corners = np.array(
301                [
302                    [block.LeftEdge[0], block.LeftEdge[1], block.LeftEdge[2]],
303                    [block.RightEdge[0], block.LeftEdge[1], block.LeftEdge[2]],
304                    [block.RightEdge[0], block.RightEdge[1], block.LeftEdge[2]],
305                    [block.LeftEdge[0], block.RightEdge[1], block.LeftEdge[2]],
306                    [block.LeftEdge[0], block.LeftEdge[1], block.RightEdge[2]],
307                    [block.RightEdge[0], block.LeftEdge[1], block.RightEdge[2]],
308                    [block.RightEdge[0], block.RightEdge[1], block.RightEdge[2]],
309                    [block.LeftEdge[0], block.RightEdge[1], block.RightEdge[2]],
310                ],
311                dtype="float64",
312            )
313            corners.append(block_corners)
314            levels.append(block.Level)
315        corners = np.dstack(corners)
316        levels = np.array(levels)
317
318        if max_level is not None:
319            subset = levels <= max_level
320            levels = levels[subset]
321            corners = corners[:, :, subset]
322        if min_level is not None:
323            subset = levels >= min_level
324            levels = levels[subset]
325            corners = corners[:, :, subset]
326
327        colors = (
328            apply_colormap(
329                levels * 1.0, color_bounds=[0, self.ds.index.max_level], cmap_name=cmap
330            )[0, :, :]
331            * 1.0
332            / 255.0
333        )
334        colors[:, 3] = alpha
335
336        order = [0, 1, 1, 2, 2, 3, 3, 0]
337        order += [4, 5, 5, 6, 6, 7, 7, 4]
338        order += [0, 4, 1, 5, 2, 6, 3, 7]
339
340        vertices = np.empty([corners.shape[2] * 2 * 12, 3])
341        vertices = self.ds.arr(vertices, "code_length")
342        for i in range(3):
343            vertices[:, i] = corners[order, i, ...].ravel(order="F")
344
345        px, py, dz = self.project_to_plane(vertices, res=im.shape[:2])
346
347        # Must normalize the image
348        nim = im.rescale(inline=False)
349        enhance_rgba(nim)
350        nim.add_background_color("black", inline=True)
351
352        # we flipped it in snapshot to get the orientation correct, so
353        # flip the lines
354        lines(nim.d, px.d, py.d, colors, 24, flip=1)
355
356        return nim
357
358    def draw_coordinate_vectors(self, im, length=0.05, thickness=1):
359        r"""Draws three coordinate vectors in the corner of a rendering.
360
361        Modifies an existing image to have three lines corresponding to the
362        coordinate directions colored by {x,y,z} = {r,g,b}.  Currently only
363        functional for plane-parallel volume rendering.
364
365        Parameters
366        ----------
367        im: Numpy ndarray
368            Existing image that has the same resolution as the Camera,
369            which will be painted by grid lines.
370        length: float, optional
371            The length of the lines, as a fraction of the image size.
372            Default : 0.05
373        thickness : int, optional
374            Thickness in pixels of the line to be drawn.
375
376        Returns
377        -------
378        None
379
380        Modifies
381        --------
382        im: The original image.
383
384        Examples
385        --------
386        >>> im = cam.snapshot()
387        >>> cam.draw_coordinate_vectors(im)
388        >>> im.write_png("render_with_grids.png")
389
390        """
391        length_pixels = length * self.resolution[0]
392        # Put the starting point in the lower left
393        px0 = int(length * self.resolution[0])
394        # CS coordinates!
395        py0 = int((1.0 - length) * self.resolution[1])
396
397        alpha = im[:, :, 3].max()
398        if alpha == 0.0:
399            alpha = 1.0
400
401        coord_vectors = [
402            np.array([length_pixels, 0.0, 0.0]),
403            np.array([0.0, length_pixels, 0.0]),
404            np.array([0.0, 0.0, length_pixels]),
405        ]
406        colors = [
407            np.array([1.0, 0.0, 0.0, alpha]),
408            np.array([0.0, 1.0, 0.0, alpha]),
409            np.array([0.0, 0.0, 1.0, alpha]),
410        ]
411
412        # we flipped it in snapshot to get the orientation correct, so
413        # flip the lines
414        for vec, color in zip(coord_vectors, colors):
415            dx = int(np.dot(vec, self.orienter.unit_vectors[0]))
416            dy = int(np.dot(vec, self.orienter.unit_vectors[1]))
417            px = np.array([px0, px0 + dx], dtype="int64")
418            py = np.array([py0, py0 + dy], dtype="int64")
419            lines(im.d, px, py, np.array([color, color]), 1, thickness, flip=1)
420
421    def draw_line(self, im, x0, x1, color=None):
422        r"""Draws a line on an existing volume rendering.
423        Given starting and ending positions x0 and x1, draws a line on
424        a volume rendering using the camera orientation.
425
426        Parameters
427        ----------
428        im : ImageArray or 2D ndarray
429            Existing image that has the same resolution as the Camera,
430            which will be painted by grid lines.
431        x0 : YTArray or ndarray
432            Starting coordinate.  If passed in as an ndarray,
433            assumed to be in code units.
434        x1 : YTArray or ndarray
435            Ending coordinate, in simulation coordinates.  If passed in as
436            an ndarray, assumed to be in code units.
437        color : array like, optional
438            Color of the line (r, g, b, a). Defaults to white.
439
440        Returns
441        -------
442        None
443
444        Examples
445        --------
446        >>> im = cam.snapshot()
447        >>> cam.draw_line(im, np.array([0.1, 0.2, 0.3]), np.array([0.5, 0.6, 0.7]))
448        >>> write_bitmap(im, "render_with_line.png")
449
450        """
451        if color is None:
452            color = np.array([1.0, 1.0, 1.0, 1.0])
453
454        if not hasattr(x0, "units"):
455            x0 = self.ds.arr(x0, "code_length")
456        if not hasattr(x1, "units"):
457            x1 = self.ds.arr(x1, "code_length")
458
459        dx0 = ((x0 - self.origin) * self.orienter.unit_vectors[1]).sum()
460        dx1 = ((x1 - self.origin) * self.orienter.unit_vectors[1]).sum()
461        dy0 = ((x0 - self.origin) * self.orienter.unit_vectors[0]).sum()
462        dy1 = ((x1 - self.origin) * self.orienter.unit_vectors[0]).sum()
463        py0 = int(self.resolution[0] * (dx0 / self.width[0]))
464        py1 = int(self.resolution[0] * (dx1 / self.width[0]))
465        px0 = int(self.resolution[1] * (dy0 / self.width[1]))
466        px1 = int(self.resolution[1] * (dy1 / self.width[1]))
467        px = np.array([px0, px1], dtype="int64")
468        py = np.array([py0, py1], dtype="int64")
469        # we flipped it in snapshot to get the orientation correct, so
470        # flip the lines
471        lines(im.d, px, py, np.array([color, color]), flip=1)
472
473    def draw_domain(self, im, alpha=0.3):
474        r"""Draws domain edges on an existing volume rendering.
475
476        Draws a white wireframe on the domain edges.
477
478        Parameters
479        ----------
480        im: Numpy ndarray
481            Existing image that has the same resolution as the Camera,
482            which will be painted by grid lines.
483        alpha : float, optional
484            The alpha value for the wireframe being drawn.  Used to control
485            how bright the lines are with respect to the image.
486            Default : 0.3
487
488        Returns
489        -------
490        nim: Numpy ndarray
491            A new image with the domain lines drawn
492
493        Examples
494        --------
495        >>> im = cam.snapshot()
496        >>> nim = cam.draw_domain(im)
497        >>> write_bitmap(nim, "render_with_domain_boundary.png")
498
499        """
500        # Must normalize the image
501        nim = im.rescale(inline=False)
502        enhance_rgba(nim)
503        nim.add_background_color("black", inline=True)
504
505        self.draw_box(
506            nim,
507            self.ds.domain_left_edge,
508            self.ds.domain_right_edge,
509            color=np.array([1.0, 1.0, 1.0, alpha]),
510        )
511        return nim
512
513    def draw_box(self, im, le, re, color=None):
514        r"""Draws a box on an existing volume rendering.
515
516        Draws a box defined by a left and right edge by modifying an
517        existing volume rendering
518
519        Parameters
520        ----------
521        im: Numpy ndarray
522            Existing image that has the same resolution as the Camera,
523            which will be painted by grid lines.
524        le: Numpy ndarray
525            Left corner of the box
526        re : Numpy ndarray
527            Right corner of the box
528        color : array like, optional
529            Color of the box (r, g, b, a). Defaults to white.
530
531        Returns
532        -------
533        None
534
535        Examples
536        --------
537        >>> im = cam.snapshot()
538        >>> cam.draw_box(im, np.array([0.1, 0.2, 0.3]), np.array([0.5, 0.6, 0.7]))
539        >>> write_bitmap(im, "render_with_box.png")
540
541        """
542
543        if color is None:
544            color = np.array([1.0, 1.0, 1.0, 1.0])
545        corners = get_corners(le, re)
546        order = [0, 1, 1, 2, 2, 3, 3, 0]
547        order += [4, 5, 5, 6, 6, 7, 7, 4]
548        order += [0, 4, 1, 5, 2, 6, 3, 7]
549
550        vertices = np.empty([24, 3])
551        vertices = self.ds.arr(vertices, "code_length")
552        for i in range(3):
553            vertices[:, i] = corners[order, i, ...].ravel(order="F")
554
555        px, py, dz = self.project_to_plane(vertices, res=im.shape[:2])
556
557        # we flipped it in snapshot to get the orientation correct, so
558        # flip the lines
559        lines(
560            im.d,
561            px.d.astype("int64"),
562            py.d.astype("int64"),
563            color.reshape(1, 4),
564            24,
565            flip=1,
566        )
567
568    def look_at(self, new_center, north_vector=None):
569        r"""Change the view direction based on a new focal point.
570
571        This will recalculate all the necessary vectors and vector planes to orient
572        the image plane so that it points at a new location.
573
574        Parameters
575        ----------
576        new_center : array_like
577            The new "center" of the view port -- the focal point for the
578            camera.
579        north_vector : array_like, optional
580            The "up" direction for the plane of rays.  If not specific,
581            calculated automatically.
582        """
583        normal_vector = self.front_center - new_center
584        self.orienter.switch_orientation(
585            normal_vector=normal_vector, north_vector=north_vector
586        )
587
588    def switch_orientation(self, normal_vector=None, north_vector=None):
589        r"""
590        Change the view direction based on any of the orientation parameters.
591
592        This will recalculate all the necessary vectors and vector planes
593        related to an orientable object.
594
595        Parameters
596        ----------
597        normal_vector: array_like, optional
598            The new looking vector.
599        north_vector : array_like, optional
600            The 'up' direction for the plane of rays.  If not specific,
601            calculated automatically.
602        """
603        if north_vector is None:
604            north_vector = self.north_vector
605        if normal_vector is None:
606            normal_vector = self.normal_vector
607        self.orienter._setup_normalized_vectors(normal_vector, north_vector)
608
609    def switch_view(
610        self, normal_vector=None, width=None, center=None, north_vector=None
611    ):
612        r"""Change the view based on any of the view parameters.
613
614        This will recalculate the orientation and width based on any of
615        normal_vector, width, center, and north_vector.
616
617        Parameters
618        ----------
619        normal_vector: array_like, optional
620            The new looking vector.
621        width: float or array of floats, optional
622            The new width.  Can be a single value W -> [W,W,W] or an
623            array [W1, W2, W3] (left/right, top/bottom, front/back)
624        center: array_like, optional
625            Specifies the new center.
626        north_vector : array_like, optional
627            The 'up' direction for the plane of rays.  If not specific,
628            calculated automatically.
629        """
630        if width is None:
631            width = self.width
632        if not is_sequence(width):
633            width = (width, width, width)  # left/right, tom/bottom, front/back
634        self.width = width
635        if center is not None:
636            self.center = center
637        if north_vector is None:
638            north_vector = self.orienter.north_vector
639        if normal_vector is None:
640            normal_vector = self.orienter.normal_vector
641        self.switch_orientation(normal_vector=normal_vector, north_vector=north_vector)
642        self._setup_box_properties(width, self.center, self.orienter.unit_vectors)
643
644    def new_image(self):
645        image = np.zeros(
646            (self.resolution[0], self.resolution[1], 4), dtype="float64", order="C"
647        )
648        return image
649
    def get_sampler_args(self, image):
        """Assemble the positional and keyword arguments for the image sampler.

        Parameters
        ----------
        image : ndarray
            The image buffer the sampler writes into.

        Returns
        -------
        (args, kwargs) : tuple
            Arguments for the plane-parallel sampler constructors; the
            positional order is fixed by the sampler interface.
        """
        # Rotation matrix raveled column-major, concatenated with the
        # back-center position.
        rotp = np.concatenate(
            [self.orienter.inv_mat.ravel("F"), self.back_center.ravel()]
        )
        args = (
            np.atleast_3d(rotp),
            np.atleast_3d(self.box_vectors[2]),
            self.back_center,
            # Image-plane bounds, centered on the view axis.
            (
                -self.width[0] / 2.0,
                self.width[0] / 2.0,
                -self.width[1] / 2.0,
                self.width[1] / 2.0,
            ),
            image,
            self.orienter.unit_vectors[0],
            self.orienter.unit_vectors[1],
            np.array(self.width, dtype="float64"),
            "KDTree",
            self.transfer_function,
            self.sub_samples,
        )
        kwargs = {
            "lens_type": "plane-parallel",
        }
        return args, kwargs
676
677    def get_sampler(self, args, kwargs):
678        if self.use_light:
679            if self.light_dir is None:
680                self.set_default_light_dir()
681            temp_dir = np.empty(3, dtype="float64")
682            temp_dir = (
683                self.light_dir[0] * self.orienter.unit_vectors[1]
684                + self.light_dir[1] * self.orienter.unit_vectors[2]
685                + self.light_dir[2] * self.orienter.unit_vectors[0]
686            )
687            if self.light_rgba is None:
688                self.set_default_light_rgba()
689            sampler = LightSourceRenderSampler(
690                *args, light_dir=temp_dir, light_rgba=self.light_rgba, **kwargs
691            )
692        else:
693            sampler = self._sampler_object(*args, **kwargs)
694        return sampler
695
    def finalize_image(self, image):
        """Reduce per-brick images into the final image and fix the alpha.

        Returns the combined image; unless the transfer function uses grey
        opacity, the alpha channel is forced to fully opaque.
        """
        # View position placed far behind the front center along the view
        # normal (1e6 widths), used by reduce_tree_images for ordering.
        view_pos = (
            self.front_center + self.orienter.unit_vectors[2] * 1.0e6 * self.width[2]
        )
        image = self.volume.reduce_tree_images(image, view_pos)
        if not self.transfer_function.grey_opacity:
            image[:, :, 3] = 1.0
        return image
704
    def _render(self, double_check, num_threads, image, sampler):
        """Ray-cast every brick through the sampler and finalize the image.

        Parameters
        ----------
        double_check : bool
            If True, verify no brick data contains NaNs before casting.
        num_threads : int
            Number of threads handed to the sampler per brick.
        image : ndarray
            Initial image buffer; not read here — the result is taken from
            ``sampler.aimage`` after traversal.
        sampler : image sampler
            The sampler produced by get_sampler.

        Returns
        -------
        image : ndarray
            The finalized rendered image.
        """
        ncells = sum(b.source_mask.size for b in self.volume.bricks)
        pbar = get_pbar("Ray casting", ncells)
        total_cells = 0
        if double_check:
            # Fail fast on NaNs, which would silently corrupt the render.
            for brick in self.volume.bricks:
                for data in brick.my_data:
                    if np.any(np.isnan(data)):
                        raise RuntimeError

        # View position far behind the front center along the view normal,
        # used by the kd-tree to order brick traversal.
        view_pos = (
            self.front_center + self.orienter.unit_vectors[2] * 1.0e6 * self.width[2]
        )
        for brick in self.volume.traverse(view_pos):
            sampler(brick, num_threads=num_threads)
            total_cells += brick.source_mask.size
            pbar.update(total_cells)

        pbar.finish()
        image = sampler.aimage
        image = self.finalize_image(image)
        return image
727
    @cached_property
    def _pyplot(self):
        # Lazy, cached import so matplotlib is only required when one of the
        # plotting helpers is actually used.
        from matplotlib import pyplot

        return pyplot
733
    def show_tf(self):
        """Display the transfer function in its own matplotlib figure.

        The figure is created (and the transfer function drawn into it)
        only on first call; later calls just redraw.
        """
        if self._tf_figure is None:
            self._tf_figure = self._pyplot.figure(2)
            self.transfer_function.show(ax=self._tf_figure.axes)
        self._pyplot.draw()
739
    def annotate(self, ax, enhance=True, label_fmt=None):
        """Attach a transfer-function colorbar to a rendered image axis.

        Parameters
        ----------
        ax : matplotlib axes
            The axes holding the rendered image (``ax.images[0]``).
        enhance : bool
            Not used in this method; accepted for interface symmetry with
            show_mpl / save_annotated.
        label_fmt : str, optional
            Format specifier for the colorbar data labels.
        """
        # Hide all axis decorations; the rendering fills the frame.
        ax.get_xaxis().set_visible(False)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_visible(False)
        ax.get_yaxis().set_ticks([])
        cb = self._pyplot.colorbar(
            ax.images[0], pad=0.0, fraction=0.05, drawedges=True, shrink=0.9
        )
        # Label the colorbar with the first rendered field, noting log scaling.
        label = self.ds._get_field_info(self.fields[0]).get_label()
        if self.log_fields[0]:
            label = r"$\rm{log}\ $" + label
        self.transfer_function.vert_cbar(ax=cb.ax, label=label, label_fmt=label_fmt)
752
753    def show_mpl(self, im, enhance=True, clear_fig=True):
754        if self._render_figure is None:
755            self._render_figure = self._pyplot.figure(1)
756        if clear_fig:
757            self._render_figure.clf()
758
759        if enhance:
760            nz = im[im > 0.0]
761            nim = im / (nz.mean() + 6.0 * np.std(nz))
762            nim[nim > 1.0] = 1.0
763            nim[nim < 0.0] = 0.0
764            del nz
765        else:
766            nim = im
767        ax = self._pyplot.imshow(nim[:, :, :3] / nim[:, :, :3].max(), origin="upper")
768        return ax
769
    def draw(self):
        """Redraw the active matplotlib figure."""
        self._pyplot.draw()
772
    def save_annotated(
        self, fn, image, enhance=True, dpi=100, clear_fig=True, label_fmt=None
    ):
        """
        Save an image with the transfer function represented as a colorbar.

        Parameters
        ----------
        fn : str
           The output filename
        image : ImageArray
           The image to annotate
        enhance : bool, optional
           Enhance the contrast (default: True)
        dpi : int, optional
           Dots per inch in the output image (default: 100)
        clear_fig : bool, optional
           Reset the figure (through matplotlib.pyplot.clf()) before drawing.  Setting
           this to false can allow us to overlay the image onto an
           existing figure
        label_fmt : str, optional
           A format specifier (e.g., label_fmt="%.2g") to use in formatting
           the data values that label the transfer function colorbar.

        """
        # Swap axes into (y, x) orientation for matplotlib display.
        image = image.swapaxes(0, 1)
        ax = self.show_mpl(image, enhance=enhance, clear_fig=clear_fig)
        self.annotate(ax.axes, enhance, label_fmt=label_fmt)
        self._pyplot.savefig(fn, bbox_inches="tight", facecolor="black", dpi=dpi)
802
803    def save_image(self, image, fn=None, clip_ratio=None, transparent=False):
804        if self.comm.rank == 0 and fn is not None:
805            if transparent:
806                image.write_png(
807                    fn, clip_ratio=clip_ratio, rescale=True, background=None
808                )
809            else:
810                image.write_png(
811                    fn, clip_ratio=clip_ratio, rescale=True, background="black"
812                )
813
    def initialize_source(self):
        """Load the camera's fields into the volume source.

        Delegates to ``self.volume.initialize_source`` with the configured
        fields, their log-scaling flags, and the ghost-zone setting.
        """
        return self.volume.initialize_source(
            self.fields, self.log_fields, self.no_ghost
        )
818
819    def get_information(self):
820        info_dict = {
821            "fields": self.fields,
822            "type": self.__class__.__name__,
823            "east_vector": self.orienter.unit_vectors[0],
824            "north_vector": self.orienter.unit_vectors[1],
825            "normal_vector": self.orienter.unit_vectors[2],
826            "width": self.width,
827            "dataset": self.ds.fullpath,
828        }
829        return info_dict
830
831    def snapshot(
832        self,
833        fn=None,
834        clip_ratio=None,
835        double_check=False,
836        num_threads=0,
837        transparent=False,
838    ):
839        r"""Ray-cast the camera.
840
841        This method instructs the camera to take a snapshot -- i.e., call the ray
842        caster -- based on its current settings.
843
844        Parameters
845        ----------
846        fn : string, optional
847            If supplied, the image will be saved out to this before being
848            returned.  Scaling will be to the maximum value.
849        clip_ratio : float, optional
850            If supplied, the 'max_val' argument to write_bitmap will be handed
851            clip_ratio * image.std()
852        double_check : bool, optional
853            Optionally makes sure that the data contains only valid entries.
854            Used for debugging.
855        num_threads : int, optional
856            If supplied, will use 'num_threads' number of OpenMP threads during
857            the rendering.  Defaults to 0, which uses the environment variable
858            OMP_NUM_THREADS.
859        transparent: bool, optional
860            Optionally saves out the 4-channel rgba image, which can appear
861            empty if the alpha channel is low everywhere. Default: False
862
863        Returns
864        -------
865        image : array
866            An (N,M,3) array of the final returned values, in float64 form.
867        """
868        if num_threads is None:
869            num_threads = get_num_threads()
870        image = self.new_image()
871        args, kwargs = self.get_sampler_args(image)
872        sampler = self.get_sampler(args, kwargs)
873        self.initialize_source()
874        image = ImageArray(
875            self._render(double_check, num_threads, image, sampler),
876            info=self.get_information(),
877        )
878
879        # flip it up/down to handle how the png orientation is done
880        image = image[:, ::-1, :]
881        self.save_image(image, fn=fn, clip_ratio=clip_ratio, transparent=transparent)
882        return image
883
884    def show(self, clip_ratio=None):
885        r"""This will take a snapshot and display the resultant image in the
886        IPython notebook.
887
888        If yt is being run from within an IPython session, and it is able to
889        determine this, this function will snapshot and send the resultant
890        image to the IPython notebook for display.
891
892        If yt can't determine if it's inside an IPython session, it will raise
893        YTNotInsideNotebook.
894
895        Parameters
896        ----------
897        clip_ratio : float, optional
898            If supplied, the 'max_val' argument to write_bitmap will be handed
899            clip_ratio * image.std()
900
901        Examples
902        --------
903
904        >>> cam.show()
905
906        """
907        if "__IPYTHON__" in dir(builtins):
908            from IPython.core.displaypub import publish_display_data
909
910            image = self.snapshot()[:, :, :3]
911            if clip_ratio is not None:
912                clip_ratio *= image.std()
913            data = write_bitmap(image, None, clip_ratio)
914            publish_display_data(
915                data={"image/png": data},
916                source="yt.visualization.volume_rendering.camera.Camera",
917            )
918        else:
919            raise YTNotInsideNotebook
920
    def set_default_light_dir(self):
        # Default light direction: along the (1, 1, 1) diagonal.
        self.light_dir = [1.0, 1.0, 1.0]
923
    def set_default_light_rgba(self):
        # Default light color: opaque white, as (R, G, B, A).
        self.light_rgba = [1.0, 1.0, 1.0, 1.0]
926
    def zoom(self, factor):
        r"""Change the distance to the focal point.

        This will zoom the camera in by some `factor` toward the focal point,
        along the current view direction, modifying the left/right and up/down
        extents as well.

        Parameters
        ----------
        factor : float
            The factor by which to reduce the distance to the focal point.


        Notes
        -----

        You will need to call snapshot() again to get a new image.

        """
        # Shrink the view width in place, then rebuild the box geometry
        # that depends on it.
        self.width /= factor
        self._setup_box_properties(self.width, self.center, self.orienter.unit_vectors)
948
949    def zoomin(self, final, n_steps, clip_ratio=None):
950        r"""Loop over a zoomin and return snapshots along the way.
951
952        This will yield `n_steps` snapshots until the current view has been
953        zooming in to a final factor of `final`.
954
955        Parameters
956        ----------
957        final : float
958            The zoom factor, with respect to current, desired at the end of the
959            sequence.
960        n_steps : int
961            The number of zoom snapshots to make.
962        clip_ratio : float, optional
963            If supplied, the 'max_val' argument to write_bitmap will be handed
964            clip_ratio * image.std()
965
966
967        Examples
968        --------
969
970        >>> for i, snapshot in enumerate(cam.zoomin(100.0, 10)):
971        ...     iw.write_bitmap(snapshot, "zoom_%04i.png" % i)
972        """
973        f = final ** (1.0 / n_steps)
974        for _ in range(n_steps):
975            self.zoom(f)
976            yield self.snapshot(clip_ratio=clip_ratio)
977
    def move_to(
        self, final, n_steps, final_width=None, exponential=False, clip_ratio=None
    ):
        r"""Loop over a look_at

        This will yield `n_steps` snapshots until the current view has been
        moved to a final center of `final` with a final width of final_width.

        Parameters
        ----------
        final : array_like
            The final center to move to after `n_steps`
        n_steps : int
            The number of look_at snapshots to make.
        final_width: float or array_like, optional
            Specifies the final width after `n_steps`.  Useful for
            moving and zooming at the same time.
        exponential : boolean
            Specifies whether the move/zoom transition follows an
            exponential path toward the destination or linear
        clip_ratio : float, optional
            If supplied, the 'max_val' argument to write_bitmap will be handed
            clip_ratio * image.std()

        Examples
        --------

        >>> for i, snapshot in enumerate(cam.move_to([0.2, 0.3, 0.6], 10)):
        ...     iw.write_bitmap(snapshot, "move_%04i.png" % i)
        """
        # dW is the per-step width change: a multiplicative factor on the
        # exponential path, an additive increment on the linear path.
        dW = None
        if not isinstance(final, YTArray):
            final = self.ds.arr(final, units="code_length")
        if exponential:
            if final_width is not None:
                if not is_sequence(final_width):
                    final_width = [final_width, final_width, final_width]
                if not isinstance(final_width, YTArray):
                    final_width = self.ds.arr(final_width, units="code_length")
                    # left/right, top/bottom, front/back
                # Nudge an exactly-zero center toward the target, presumably
                # to avoid dividing by zero in `final / self.center` below.
                # NOTE(review): this guard only runs when final_width is
                # given, yet the division happens in both cases -- confirm.
                if (self.center == 0.0).all():
                    self.center += (final - self.center) / (10.0 * n_steps)
                final_zoom = final_width / self.width
                dW = final_zoom ** (1.0 / n_steps)
            else:
                # NOTE(review): used multiplicatively below but carries
                # code_length units -- looks inconsistent; confirm.
                dW = self.ds.arr([1.0, 1.0, 1.0], "code_length")
            # Per-step multiplicative step from the current to final center.
            position_diff = final / self.center
            dx = position_diff ** (1.0 / n_steps)
        else:
            if final_width is not None:
                if not is_sequence(final_width):
                    final_width = [final_width, final_width, final_width]
                if not isinstance(final_width, YTArray):
                    final_width = self.ds.arr(final_width, units="code_length")
                    # left/right, top/bottom, front/back
                # Per-step additive width change.
                dW = (1.0 * final_width - self.width) / n_steps
            else:
                dW = self.ds.arr([0.0, 0.0, 0.0], "code_length")
            # Per-step additive move toward the final center.
            dx = (final - self.center) * 1.0 / n_steps
        for _ in range(n_steps):
            if exponential:
                self.switch_view(center=self.center * dx, width=self.width * dW)
            else:
                self.switch_view(center=self.center + dx, width=self.width + dW)
            yield self.snapshot(clip_ratio=clip_ratio)
1043
1044    def rotate(self, theta, rot_vector=None):
1045        r"""Rotate by a given angle
1046
1047        Rotate the view.  If `rot_vector` is None, rotation will occur
1048        around the `north_vector`.
1049
1050        Parameters
1051        ----------
1052        theta : float, in radians
1053             Angle (in radians) by which to rotate the view.
1054        rot_vector  : array_like, optional
1055            Specify the rotation vector around which rotation will
1056            occur.  Defaults to None, which sets rotation around
1057            `north_vector`
1058
1059        Examples
1060        --------
1061
1062        >>> cam.rotate(np.pi / 4)
1063        """
1064        rotate_all = rot_vector is not None
1065        if rot_vector is None:
1066            rot_vector = self.rotation_vector
1067        else:
1068            rot_vector = ensure_numpy_array(rot_vector)
1069            rot_vector = rot_vector / np.linalg.norm(rot_vector)
1070
1071        R = get_rotation_matrix(theta, rot_vector)
1072
1073        normal_vector = self.front_center - self.center
1074        normal_vector = normal_vector / np.sqrt((normal_vector ** 2).sum())
1075
1076        if rotate_all:
1077            self.switch_view(
1078                normal_vector=np.dot(R, normal_vector),
1079                north_vector=np.dot(R, self.orienter.unit_vectors[1]),
1080            )
1081        else:
1082            self.switch_view(normal_vector=np.dot(R, normal_vector))
1083
1084    def pitch(self, theta):
1085        r"""Rotate by a given angle about the horizontal axis
1086
1087        Pitch the view.
1088
1089        Parameters
1090        ----------
1091        theta : float, in radians
1092             Angle (in radians) by which to pitch the view.
1093
1094        Examples
1095        --------
1096
1097        >>> cam.pitch(np.pi / 4)
1098        """
1099        rot_vector = self.orienter.unit_vectors[0]
1100        R = get_rotation_matrix(theta, rot_vector)
1101        self.switch_view(
1102            normal_vector=np.dot(R, self.orienter.unit_vectors[2]),
1103            north_vector=np.dot(R, self.orienter.unit_vectors[1]),
1104        )
1105        if self.orienter.steady_north:
1106            self.orienter.north_vector = self.orienter.unit_vectors[1]
1107
1108    def yaw(self, theta):
1109        r"""Rotate by a given angle about the vertical axis
1110
1111        Yaw the view.
1112
1113        Parameters
1114        ----------
1115        theta : float, in radians
1116             Angle (in radians) by which to yaw the view.
1117
1118        Examples
1119        --------
1120
1121        >>> cam.yaw(np.pi / 4)
1122        """
1123        rot_vector = self.orienter.unit_vectors[1]
1124        R = get_rotation_matrix(theta, rot_vector)
1125        self.switch_view(normal_vector=np.dot(R, self.orienter.unit_vectors[2]))
1126
1127    def roll(self, theta):
1128        r"""Rotate by a given angle about the view normal axis
1129
1130        Roll the view.
1131
1132        Parameters
1133        ----------
1134        theta : float, in radians
1135             Angle (in radians) by which to roll the view.
1136
1137        Examples
1138        --------
1139
1140        >>> cam.roll(np.pi / 4)
1141        """
1142        rot_vector = self.orienter.unit_vectors[2]
1143        R = get_rotation_matrix(theta, rot_vector)
1144        self.switch_view(
1145            normal_vector=np.dot(R, self.orienter.unit_vectors[2]),
1146            north_vector=np.dot(R, self.orienter.unit_vectors[1]),
1147        )
1148        if self.orienter.steady_north:
1149            self.orienter.north_vector = np.dot(R, self.orienter.north_vector)
1150
1151    def rotation(self, theta, n_steps, rot_vector=None, clip_ratio=None):
1152        r"""Loop over rotate, creating a rotation
1153
1154        This will yield `n_steps` snapshots until the current view has been
1155        rotated by an angle `theta`
1156
1157        Parameters
1158        ----------
1159        theta : float, in radians
1160            Angle (in radians) by which to rotate the view.
1161        n_steps : int
1162            The number of look_at snapshots to make.
1163        rot_vector  : array_like, optional
1164            Specify the rotation vector around which rotation will
1165            occur.  Defaults to None, which sets rotation around the
1166            original `north_vector`
1167        clip_ratio : float, optional
1168            If supplied, the 'max_val' argument to write_bitmap will be handed
1169            clip_ratio * image.std()
1170
1171        Examples
1172        --------
1173
1174        >>> for i, snapshot in enumerate(cam.rotation(np.pi, 10)):
1175        ...     iw.write_bitmap(snapshot, "rotation_%04i.png" % i)
1176        """
1177
1178        dtheta = (1.0 * theta) / n_steps
1179        for _ in range(n_steps):
1180            self.rotate(dtheta, rot_vector=rot_vector)
1181            yield self.snapshot(clip_ratio=clip_ratio)
1182
1183
# Expose Camera through yt's data-object registry under the "camera" key.
data_object_registry["camera"] = Camera
1185
1186
class InteractiveCamera(Camera):
    """A Camera that displays each snapshot through matplotlib as it is
    taken, and accumulates the rendered frames for later saving.
    """

    # Frame store.  NOTE: declared at class level, so instances that have
    # never called clear_frames() share (and append to) the same list.
    frames = []

    def snapshot(self, fn=None, clip_ratio=None):
        """Take a snapshot, show the transfer function (figure 2) and the
        image (figure 1), and append the image to ``self.frames``.

        Parameters
        ----------
        fn : str, optional
            Forwarded to ``Camera.snapshot``; if supplied the image is also
            written to this filename.
        clip_ratio : float, optional
            Forwarded to ``Camera.snapshot``.
        """
        self._pyplot.figure(2)
        self.transfer_function.show()
        self._pyplot.draw()
        im = Camera.snapshot(self, fn, clip_ratio)
        self._pyplot.figure(1)
        self._pyplot.imshow(im / im.max())
        self._pyplot.draw()
        self.frames.append(im)

    def rotation(self, theta, n_steps, rot_vector=None):
        """Rotate by `theta` over `n_steps`, storing each frame."""
        for frame in Camera.rotation(self, theta, n_steps, rot_vector):
            if frame is not None:
                self.frames.append(frame)

    def zoomin(self, final, n_steps):
        """Zoom in to `final` over `n_steps`, storing each frame."""
        for frame in Camera.zoomin(self, final, n_steps):
            if frame is not None:
                self.frames.append(frame)

    def clear_frames(self):
        """Drop all stored frames.

        Rebinds a fresh list on the instance rather than ``del self.frames``
        followed by re-assignment: the ``del`` raised AttributeError when
        only the class-level list existed (i.e. on a camera that had never
        cleared its frames before).
        """
        self.frames = []

    def save(self, fn):
        """Save the current matplotlib figure to *fn*."""
        self._pyplot.savefig(fn, bbox_inches="tight", facecolor="black")

    def save_frames(self, basename, clip_ratio=None):
        """Write every stored frame to ``<basename>_%04i.png``.

        Parameters
        ----------
        basename : str
            Filename prefix for the output images.
        clip_ratio : float, optional
            If supplied, write_bitmap is handed ``clip_ratio * frame.std()``
            as its max_val.
        """
        for i, frame in enumerate(self.frames):
            fn = basename + "_%04i.png" % i
            if clip_ratio is not None:
                write_bitmap(frame, fn, clip_ratio * frame.std())
            else:
                write_bitmap(frame, fn)
1224
1225
# Expose InteractiveCamera through yt's data-object registry.
data_object_registry["interactive_camera"] = InteractiveCamera
1227
1228
class PerspectiveCamera(Camera):
    r"""A viewpoint into a volume, for perspective volume rendering.

    The camera represents the eye of an observer, which will be used to
    generate ray-cast volume renderings of the domain. The rays start from
    the camera and end on the image plane, which generates a perspective
    view.

    Note: at the moment, this results in a left-handed coordinate
    system view

    Parameters
    ----------
    center : array_like
        The location of the camera
    normal_vector : array_like
        The vector from the camera position to the center of the image plane
    width : float or list of floats
        width[0] and width[1] give the width and height of the image plane, and
        width[2] gives the depth of the image plane (distance between the camera
        and the center of the image plane).
        The view angles thus become:
        2 * arctan(0.5 * width[0] / width[2]) in horizontal direction
        2 * arctan(0.5 * width[1] / width[2]) in vertical direction
    (The following parameters are identical with the definitions in Camera class)
    resolution : int or list of ints
        The number of pixels in each direction.
    transfer_function : `yt.visualization.volume_rendering.TransferFunction`
        The transfer function used to map values to colors in an image.  If
        not specified, defaults to a ProjectionTransferFunction.
    north_vector : array_like, optional
        The 'up' direction for the plane of rays.  If not specific, calculated
        automatically.
    steady_north : bool, optional
        Boolean to control whether to normalize the north_vector
        by subtracting off the dot product of it and the normal
        vector.  Makes it easier to do rotations along a single
        axis.  If north_vector is specified, is switched to
        True. Default: False
    volume : `yt.extensions.volume_rendering.AMRKDTree`, optional
        The volume to ray cast through.  Can be specified for finer-grained
        control, but otherwise will be automatically generated.
    fields : list of fields, optional
        This is the list of fields we want to volume render; defaults to
        Density.
    log_fields : list of bool, optional
        Whether we should take the log of the fields before supplying them to
        the volume rendering mechanism.
    sub_samples : int, optional
        The number of samples to take inside every cell per ray.
    ds : ~yt.data_objects.static_output.Dataset
        For now, this is a require parameter!  But in the future it will become
        optional.  This is the dataset to volume render.
    use_kd: bool, optional
        Specifies whether or not to use a kd-Tree framework for
        the Homogenized Volume and ray-casting.  Default to True.
    max_level: int, optional
        Specifies the maximum level to be rendered.  Also
        specifies the maximum level used in the kd-Tree
        construction.  Defaults to None (all levels), and only
        applies if use_kd=True.
    no_ghost: bool, optional
        Optimization option.  If True, homogenized bricks will
        extrapolate out from grid instead of interpolating from
        ghost zones that have to first be calculated.  This can
        lead to large speed improvements, but at a loss of
        accuracy/smoothness in resulting image.  The effects are
        less notable when the transfer function is smooth and
        broad. Default: True
    data_source: data container, optional
        Optionally specify an arbitrary data source to the volume rendering.
        All cells not included in the data source will be ignored during ray
        casting. By default this will get set to ds.all_data().

    """

    def __init__(self, *args, **kwargs):
        # No perspective-specific state; all setup happens in Camera.
        Camera.__init__(self, *args, **kwargs)

    def get_sampler_args(self, image):
        """Build the argument tuple for the perspective ray sampler.

        Every ray originates at the camera center; its direction is the sum
        of an in-plane pixel offset (spanned by the east/north unit vectors,
        scaled by width[0]/width[1]) and the view normal scaled by the image
        plane depth width[2].
        """
        east_vec = self.orienter.unit_vectors[0].reshape(3, 1)
        north_vec = self.orienter.unit_vectors[1].reshape(3, 1)

        # Fractional pixel coordinates in [-0.5, 0.5] along each image axis.
        px = np.linspace(-0.5, 0.5, self.resolution[0])[np.newaxis, :]
        py = np.linspace(-0.5, 0.5, self.resolution[1])[np.newaxis, :]

        # Physical in-plane displacement of each pixel column / row.
        sample_x = self.width[0] * np.array(east_vec * px).transpose()
        sample_y = self.width[1] * np.array(north_vec * py).transpose()

        # NOTE(review): this allocation is dead -- `vectors` is rebound
        # wholesale below.
        vectors = np.zeros(
            (self.resolution[0], self.resolution[1], 3), dtype="float64", order="C"
        )

        # Broadcast the per-column / per-row offsets to the full pixel grid.
        sample_x = np.repeat(
            sample_x.reshape(self.resolution[0], 1, 3), self.resolution[1], axis=1
        )
        sample_y = np.repeat(
            sample_y.reshape(1, self.resolution[1], 3), self.resolution[0], axis=0
        )

        # Constant view-normal component, replicated for every pixel.
        normal_vec = np.zeros(
            (self.resolution[0], self.resolution[1], 3), dtype="float64", order="C"
        )
        normal_vec[:, :, 0] = self.orienter.unit_vectors[2, 0]
        normal_vec[:, :, 1] = self.orienter.unit_vectors[2, 1]
        normal_vec[:, :, 2] = self.orienter.unit_vectors[2, 2]

        vectors = sample_x + sample_y + normal_vec * self.width[2]

        # All rays share the same origin: the camera center.
        positions = np.zeros(
            (self.resolution[0], self.resolution[1], 3), dtype="float64", order="C"
        )
        positions[:, :, 0] = self.center[0]
        positions[:, :, 1] = self.center[1]
        positions[:, :, 2] = self.center[2]

        positions = self.ds.arr(positions, units="code_length")

        # NOTE(review): `dummy` fills the unit-vector slots of the sampler
        # signature; presumably ignored for the perspective lens -- confirm
        # against the image_samplers implementation.
        dummy = np.ones(3, dtype="float64")
        image.shape = (self.resolution[0], self.resolution[1], 4)

        args = (
            positions,
            vectors,
            self.back_center,
            (0.0, 1.0, 0.0, 1.0),
            image,
            dummy,
            dummy,
            np.zeros(3, dtype="float64"),
            "KDTree",
            self.transfer_function,
            self.sub_samples,
        )
        kwargs = {
            "lens_type": "perspective",
        }
        return args, kwargs

    def _render(self, double_check, num_threads, image, sampler):
        """Cast rays through every brick of the volume and composite."""
        ncells = sum(b.source_mask.size for b in self.volume.bricks)
        pbar = get_pbar("Ray casting", ncells)
        total_cells = 0
        if double_check:
            # Debug guard: fail fast if any brick contains NaNs.
            for brick in self.volume.bricks:
                for data in brick.my_data:
                    if np.any(np.isnan(data)):
                        raise RuntimeError

        for brick in self.volume.traverse(self.front_center):
            sampler(brick, num_threads=num_threads)
            total_cells += brick.source_mask.size
            pbar.update(total_cells)

        pbar.finish()
        image = self.finalize_image(sampler.aimage)
        return image

    def finalize_image(self, image):
        """Reduce partial images across the tree and fix the alpha channel."""
        view_pos = self.front_center
        image.shape = self.resolution[0], self.resolution[1], 4
        image = self.volume.reduce_tree_images(image, view_pos)
        if not self.transfer_function.grey_opacity:
            # Force full opacity when the transfer function does not carry
            # grey-opacity information.
            image[:, :, 3] = 1.0
        return image

    def project_to_plane(self, pos, res=None):
        """Project world-space points `pos` onto the image plane.

        Returns ``(px, py, dz)``: integer pixel coordinates of each point
        plus its offset along the view normal from the plane center.
        """
        if res is None:
            res = self.resolution
        sight_vector = pos - self.center
        # `pos1` aliases the same array, so the in-place normalization below
        # also updates it; once `sight_vector` is rebound by ds.arr(), pos1
        # becomes an independent output buffer (fully overwritten later).
        pos1 = sight_vector
        for i in range(0, sight_vector.shape[0]):
            sight_vector_norm = np.sqrt(np.dot(sight_vector[i], sight_vector[i]))
            sight_vector[i] = sight_vector[i] / sight_vector_norm
        sight_vector = self.ds.arr(sight_vector.value, units="dimensionless")
        sight_center = self.center + self.width[2] * self.orienter.unit_vectors[2]

        for i in range(0, sight_vector.shape[0]):
            sight_angle_cos = np.dot(sight_vector[i], self.orienter.unit_vectors[2])
            if np.arccos(sight_angle_cos) < 0.5 * np.pi:
                # Point lies in front of the camera: intersect the plane.
                sight_length = self.width[2] / sight_angle_cos
            else:
                # The corner is on the backwards, then put it outside of the
                # image It can not be simply removed because it may connect to
                # other corner within the image, which produces visible domain
                # boundary line
                sight_length = np.sqrt(
                    self.width[0] ** 2 + self.width[1] ** 2
                ) / np.sqrt(1 - sight_angle_cos ** 2)
            pos1[i] = self.center + sight_length * sight_vector[i]

        # Decompose the projected points into image-plane coordinates.
        dx = np.dot(pos1 - sight_center, self.orienter.unit_vectors[0])
        dy = np.dot(pos1 - sight_center, self.orienter.unit_vectors[1])
        dz = np.dot(pos1 - sight_center, self.orienter.unit_vectors[2])
        # Transpose into image coords.
        px = (res[0] * 0.5 + res[0] / self.width[0] * dx).astype("int")
        py = (res[1] * 0.5 + res[1] / self.width[1] * dy).astype("int")
        return px, py, dz

    def yaw(self, theta, rot_center):
        r"""Rotate by a given angle about the vertical axis through the
        point center.  This is accomplished by rotating the
        focal point and then setting the looking vector to point
        to the center.

        Yaw the view.

        Parameters
        ----------
        theta : float, in radians
             Angle (in radians) by which to yaw the view.

        rot_center : a tuple (x, y, z)
             The point to rotate about

        Examples
        --------

        >>> cam.yaw(np.pi / 4, (0.0, 0.0, 0.0))
        """
        # NOTE(review): this override adds a required `rot_center` argument,
        # so it is not call-compatible with Camera.yaw(theta).

        rot_vector = self.orienter.unit_vectors[1]

        # Rotate the camera position about rot_center.
        focal_point = self.center - rot_center
        R = get_rotation_matrix(theta, rot_vector)
        focal_point = np.dot(R, focal_point) + rot_center

        # Re-aim the camera at the rotation center.
        normal_vector = rot_center - focal_point
        normal_vector = normal_vector / np.sqrt((normal_vector ** 2).sum())

        self.switch_view(normal_vector=normal_vector, center=focal_point)
1460
1461
# Expose PerspectiveCamera through yt's data-object registry.
data_object_registry["perspective_camera"] = PerspectiveCamera
1463
1464
def corners(left_edge, right_edge):
    """Return corner coordinates built from per-row left/right edges.

    For inputs of shape (n, 3), returns an (8, 3, n) float64 array: eight
    corner rows, each holding the x/y/z component arrays picked from either
    the left or the right edge (in the same fixed order as the original).
    """
    lo, hi = left_edge, right_edge
    # Which edge supplies each of (x, y, z) for the eight corners.
    picks = (
        (lo, lo, lo),
        (hi, lo, lo),
        (hi, hi, lo),
        (hi, hi, hi),
        (lo, hi, hi),
        (lo, lo, hi),
        (hi, lo, hi),
        (lo, hi, lo),
    )
    return np.array(
        [[ex[:, 0], ey[:, 1], ez[:, 2]] for ex, ey, ez in picks],
        dtype="float64",
    )
1479
1480
class HEALpixCamera(Camera):
    """All-sky camera sampling one ray per HEALpix pixel.

    NOTE: non-functional -- ``__init__`` raises NotImplementedError
    unconditionally, so the remaining methods are unreachable through
    normal construction.
    """

    # Overrides the sampler cached on the base class.
    _sampler_object = None

    def __init__(
        self,
        center,
        radius,
        nside,
        transfer_function=None,
        fields=None,
        sub_samples=5,
        log_fields=None,
        volume=None,
        ds=None,
        use_kd=True,
        no_ghost=False,
        use_light=False,
        inner_radius=10,
    ):
        mylog.error("I am sorry, HEALpix Camera does not work yet in 3.0")
        raise NotImplementedError

    def new_image(self):
        # One RGBA sample per HEALpix pixel; npix = 12 * nside**2.
        image = np.zeros((12 * self.nside ** 2, 1, 4), dtype="float64", order="C")
        return image

    def get_sampler_args(self, image):
        """Build the sampler arguments: one radial ray per HEALpix pixel."""
        nv = 12 * self.nside ** 2
        vs = arr_pix2vec_nest(self.nside, np.arange(nv))
        vs.shape = (nv, 1, 3)
        # Tiny offset applied to every direction vector.
        # NOTE(review): presumably avoids exactly axis-aligned rays -- confirm.
        vs += 1e-8
        uv = np.ones(3, dtype="float64")
        positions = np.ones((nv, 1, 3), dtype="float64") * self.center
        # Start each ray inner_radius local cell widths out from the center.
        dx = min(g.dds.min() for g in self.ds.index.find_point(self.center)[0])
        positions += self.inner_radius * dx * vs
        vs *= self.radius
        args = (
            positions,
            vs,
            self.center,
            (0.0, 1.0, 0.0, 1.0),
            image,
            uv,
            uv,
            np.zeros(3, dtype="float64"),
            "KDTree",
        )
        # The transfer function is only appended for samplers that need one.
        if self._needs_tf:
            args += (self.transfer_function,)
        args += (self.sub_samples,)

        return args, {}

    def _render(self, double_check, num_threads, image, sampler):
        """Cast the rays brick by brick and return the sampled image."""
        pbar = get_pbar(
            "Ray casting", (self.volume.brick_dimensions + 1).prod(axis=-1).sum()
        )
        total_cells = 0
        if double_check:
            # Debug guard: fail fast if any brick contains NaNs.
            for brick in self.volume.bricks:
                for data in brick.my_data:
                    if np.any(np.isnan(data)):
                        raise RuntimeError

        view_pos = self.center
        for brick in self.volume.traverse(view_pos):
            sampler(brick, num_threads=num_threads)
            total_cells += np.prod(brick.my_data[0].shape)
            pbar.update(total_cells)

        pbar.finish()
        image = sampler.aimage

        # NOTE(review): the reduced image returned by finalize_image is
        # discarded here; the un-reduced sampler image is returned -- confirm
        # whether that is intended.
        self.finalize_image(image)

        return image

    def finalize_image(self, image):
        # Reduce partial per-process images into the final composite.
        view_pos = self.center
        image = self.volume.reduce_tree_images(image, view_pos)
        return image

    def get_information(self):
        """Collect provenance metadata attached to rendered images."""
        info_dict = {
            "fields": self.fields,
            "type": self.__class__.__name__,
            "center": self.center,
            "radius": self.radius,
            "dataset": self.ds.fullpath,
        }
        return info_dict

    def snapshot(
        self,
        fn=None,
        clip_ratio=None,
        double_check=False,
        num_threads=0,
        clim=None,
        label=None,
    ):
        r"""Ray-cast the camera.

        This method instructs the camera to take a snapshot -- i.e., call the ray
        caster -- based on its current settings.

        Parameters
        ----------
        fn : string, optional
            If supplied, the image will be saved out to this before being
            returned.  Scaling will be to the maximum value.
        clip_ratio : float, optional
            If supplied, the 'max_val' argument to write_bitmap will be handed
            clip_ratio * image.std()

        Returns
        -------
        image : array
            An (N,M,3) array of the final returned values, in float64 form.
        """
        if num_threads is None:
            num_threads = get_num_threads()
        image = self.new_image()
        args, kwargs = self.get_sampler_args(image)
        sampler = self.get_sampler(args, kwargs)
        self.volume.initialize_source()
        image = ImageArray(
            self._render(double_check, num_threads, image, sampler),
            info=self.get_information(),
        )
        self.save_image(image, fn=fn, clim=clim, label=label)
        return image

    def save_image(self, image, fn=None, clim=None, label=None):
        """Plot the all-sky image via healpix and save it, on rank 0 only."""
        if self.comm.rank == 0 and fn is not None:
            # This assumes Density; this is a relatively safe assumption.
            if label is None:
                label = f"Projected {self.fields[0]}"
            if clim is not None:
                cmin, cmax = clim
            else:
                cmin = cmax = None
            plot_allsky_healpix(
                image[:, 0, 0], self.nside, fn, label, cmin=cmin, cmax=cmax
            )
1627
1628
class StereoPairCamera(Camera):
    r"""Derive a left/right eye pair of cameras from an existing camera.

    The two eyes share the original camera's settings and are offset
    symmetrically along the camera's east vector by a fraction
    (``relative_separation``) of the image width.
    """

    def __init__(self, original_camera, relative_separation=0.005):
        ParallelAnalysisInterface.__init__(self)
        self.original_camera = original_camera
        self.relative_separation = relative_separation

    def split(self):
        """Return a ``(left_camera, right_camera)`` tuple of new Cameras."""
        oc = self.original_camera
        basis = oc.orienter.unit_vectors
        center = oc.center
        # Half the eye separation, expressed in the image-plane width.
        offset = basis[1] * 0.5 * self.relative_separation * oc.width[0]
        shared = dict(
            north_vector=basis[0],
            volume=oc.volume,
            fields=oc.fields,
            log_fields=oc.log_fields,
            sub_samples=oc.sub_samples,
            ds=oc.ds,
        )
        eyes = []
        # +1 builds the left eye, -1 the right eye.
        for sign in (1, -1):
            normal = oc.front_center + sign * offset - center
            eyes.append(
                Camera(
                    center,
                    normal,
                    oc.width,
                    oc.resolution,
                    oc.transfer_function,
                    **shared,
                )
            )
        return (eyes[0], eyes[1])
1670
1671
class FisheyeCamera(Camera):
    r"""A fisheye-lens camera.

    Rays are cast from a single point (``center``) out to a distance
    ``radius``, spread over ``fov`` degrees, producing a square image of
    ``resolution**2`` pixels.  An optional 3x3 ``rotation`` matrix
    re-orients the ray bundle.
    """

    def __init__(
        self,
        center,
        radius,
        fov,
        resolution,
        transfer_function=None,
        fields=None,
        sub_samples=5,
        log_fields=None,
        volume=None,
        ds=None,
        no_ghost=False,
        rotation=None,
        use_light=False,
    ):
        ParallelAnalysisInterface.__init__(self)
        self.use_light = use_light
        self.light_dir = None
        self.light_rgba = None
        if rotation is None:
            # Default to no rotation (identity matrix).
            rotation = np.eye(3)
        self.rotation_matrix = rotation
        self.no_ghost = no_ghost
        if ds is not None:
            self.ds = ds
        self.center = np.array(center, dtype="float64")
        self.radius = radius
        self.fov = fov
        if is_sequence(resolution):
            # Fisheye images are always square: a single int is required.
            raise RuntimeError("Resolution must be a single int")
        self.resolution = resolution
        if transfer_function is None:
            transfer_function = ProjectionTransferFunction()
        self.transfer_function = transfer_function
        if fields is None:
            fields = [("gas", "density")]
        dd = self.ds.all_data()
        fields = dd._determine_fields(fields)
        self.fields = fields
        if log_fields is None:
            # Default to each field's own take_log preference.
            log_fields = [self.ds._get_field_info(*f).take_log for f in fields]
        self.log_fields = log_fields
        self.sub_samples = sub_samples
        if volume is None:
            volume = AMRKDTree(self.ds)
            volume.set_fields(fields, log_fields, no_ghost)
        self.volume = volume

    def get_information(self):
        # No metadata is attached to fisheye images.
        return {}

    def new_image(self):
        # Flat buffer of resolution**2 RGBA pixels; reshaped to a square
        # image by finalize_image.
        image = np.zeros((self.resolution ** 2, 1, 4), dtype="float64", order="C")
        return image

    def get_sampler_args(self, image):
        # Per-pixel unit vectors for the fisheye projection.
        vp = arr_fisheye_vectors(self.resolution, self.fov)
        vp.shape = (self.resolution ** 2, 1, 3)
        vp2 = vp.copy()
        # Apply the rotation matrix to every pixel vector.
        for i in range(3):
            vp[:, :, i] = (vp2 * self.rotation_matrix[:, i]).sum(axis=2)
        del vp2
        # Scale the unit directions to the maximum ray length.
        vp *= self.radius
        uv = np.ones(3, dtype="float64")
        # All rays originate from the camera center.
        positions = np.ones((self.resolution ** 2, 1, 3), dtype="float64") * self.center

        args = (
            positions,
            vp,
            self.center,
            (0.0, 1.0, 0.0, 1.0),
            image,
            uv,
            uv,
            np.zeros(3, dtype="float64"),
            "KDTree",
            self.transfer_function,
            self.sub_samples,
        )
        return args, {}

    def finalize_image(self, image):
        # Reshape the flat pixel buffer into a square RGBA image (in place).
        image.shape = self.resolution, self.resolution, 4

    def _render(self, double_check, num_threads, image, sampler):
        """Cast rays through every brick of the volume and collect the image."""
        pbar = get_pbar(
            "Ray casting", (self.volume.brick_dimensions + 1).prod(axis=-1).sum()
        )
        total_cells = 0
        if double_check:
            # Optional sanity pass: fail fast if any brick contains NaNs.
            for brick in self.volume.bricks:
                for data in brick.my_data:
                    if np.any(np.isnan(data)):
                        raise RuntimeError

        view_pos = self.center
        # Traverse bricks in view order, sampling each into the image.
        for brick in self.volume.traverse(view_pos):
            sampler(brick, num_threads=num_threads)
            total_cells += np.prod(brick.my_data[0].shape)
            pbar.update(total_cells)

        pbar.finish()
        image = sampler.aimage

        self.finalize_image(image)

        return image
1781
1782
class MosaicCamera(Camera):
    r"""A camera that renders a large image as an ``nimx`` x ``nimy``
    mosaic of tiles, optionally distributing tiles over parallel work
    groups and stitching the results back together.
    """

    def __init__(
        self,
        center,
        normal_vector,
        width,
        resolution,
        transfer_function=None,
        north_vector=None,
        steady_north=False,
        volume=None,
        fields=None,
        log_fields=None,
        sub_samples=5,
        ds=None,
        use_kd=True,
        l_max=None,
        no_ghost=True,
        tree_type="domain",
        expand_factor=1.0,
        le=None,
        re=None,
        nimx=1,
        nimy=1,
        procs_per_wg=None,
        preload=True,
        use_light=False,
    ):

        ParallelAnalysisInterface.__init__(self)

        self.procs_per_wg = procs_per_wg
        if ds is not None:
            self.ds = ds
        if not is_sequence(resolution):
            # A scalar resolution describes the full mosaic; each tile
            # renders an equal share of it.
            resolution = (int(resolution / nimx), int(resolution / nimy))
        self.resolution = resolution
        self.nimx = nimx
        self.nimy = nimy
        self.sub_samples = sub_samples
        if not is_sequence(width):
            width = (width, width, width)  # front/back, left/right, top/bottom
        self.width = np.array([width[0], width[1], width[2]])
        self.center = center
        self.steady_north = steady_north
        self.expand_factor = expand_factor
        # This seems to be necessary for now.  Not sure what goes wrong when not true.
        if north_vector is not None:
            self.steady_north = True
        self.north_vector = north_vector
        self.normal_vector = normal_vector
        if fields is None:
            fields = [("gas", "density")]
        self.fields = fields
        if transfer_function is None:
            transfer_function = ProjectionTransferFunction()
        self.transfer_function = transfer_function
        self.log_fields = log_fields
        self.use_kd = use_kd
        self.l_max = l_max
        self.no_ghost = no_ghost
        self.preload = preload

        self.use_light = use_light
        self.light_dir = None
        self.light_rgba = None
        self.le = le
        self.re = re
        # Convert the full image-plane width into a per-tile width.
        self.width[0] /= self.nimx
        self.width[1] /= self.nimy

        self.orienter = Orientation(
            normal_vector, north_vector=north_vector, steady_north=steady_north
        )
        self.rotation_vector = self.orienter.north_vector
        # self._setup_box_properties(width, center, self.orienter.unit_vectors)

        if self.no_ghost:
            mylog.warning(
                "no_ghost is currently True (default). "
                "This may lead to artifacts at grid boundaries."
            )
        self.tree_type = tree_type
        self.volume = volume

        # self.cameras = np.empty(self.nimx*self.nimy)

    def build_volume(
        self, volume, fields, log_fields, l_max, no_ghost, tree_type, le, re
    ):
        """Return *volume*, constructing an AMRKDTree when none is given."""
        if volume is None:
            # NOTE(review): this raises when use_kd is True, which looks
            # inverted relative to the name -- confirm intended condition.
            if self.use_kd:
                raise NotImplementedError
            volume = AMRKDTree(
                self.ds,
                l_max=l_max,
                fields=self.fields,
                no_ghost=no_ghost,
                tree_type=tree_type,
                log_fields=log_fields,
                le=le,
                re=re,
            )
        else:
            self.use_kd = isinstance(volume, AMRKDTree)
        return volume

    def new_image(self):
        # RGBA buffer sized for a single tile.
        image = np.zeros(
            (self.resolution[0], self.resolution[1], 4), dtype="float64", order="C"
        )
        return image

    def _setup_box_properties(self, width, center, unit_vectors):
        """Position the camera on tile (imi, imj) of the mosaic.

        Temporarily rebinds ``self.center`` to the tile center while
        computing the box vectors, then restores the global center and
        width before returning.
        """
        owidth = deepcopy(width)
        self.width = width
        self.origin = (
            self.center
            - 0.5 * self.nimx * self.width[0] * self.orienter.unit_vectors[0]
            - 0.5 * self.nimy * self.width[1] * self.orienter.unit_vectors[1]
            - 0.5 * self.width[2] * self.orienter.unit_vectors[2]
        )
        dx = self.width[0]
        dy = self.width[1]
        offi = self.imi + 0.5
        offj = self.imj + 0.5
        mylog.info("Mosaic offset: %f %f", offi, offj)
        global_center = self.center
        # Shift the working center to the middle of this tile.
        self.center = self.origin
        self.center += offi * dx * self.orienter.unit_vectors[0]
        self.center += offj * dy * self.orienter.unit_vectors[1]

        self.box_vectors = np.array(
            [
                self.orienter.unit_vectors[0] * dx * self.nimx,
                self.orienter.unit_vectors[1] * dy * self.nimy,
                self.orienter.unit_vectors[2] * self.width[2],
            ]
        )
        # NOTE(review): back/front centers use width[0] along the depth
        # axis (unit_vectors[2]) while box_vectors uses width[2] -- confirm
        # this asymmetry is intentional.
        self.back_center = (
            self.center - 0.5 * self.width[0] * self.orienter.unit_vectors[2]
        )
        self.front_center = (
            self.center + 0.5 * self.width[0] * self.orienter.unit_vectors[2]
        )
        # Restore the global state mutated above.
        self.center = global_center
        self.width = owidth

    def snapshot(self, fn=None, clip_ratio=None, double_check=False, num_threads=0):
        """Render every tile (distributed via parallel_objects), stitch
        them together, and optionally save the mosaic to *fn*."""

        my_storage = {}
        offx, offy = np.meshgrid(range(self.nimx), range(self.nimy))
        offxy = zip(offx.ravel(), offy.ravel())

        for sto, xy in parallel_objects(
            offxy, self.procs_per_wg, storage=my_storage, dynamic=True
        ):
            self.volume = self.build_volume(
                self.volume,
                self.fields,
                self.log_fields,
                self.l_max,
                self.no_ghost,
                self.tree_type,
                self.le,
                self.re,
            )
            self.initialize_source()

            self.imi, self.imj = xy
            mylog.debug("Working on: %i %i", self.imi, self.imj)
            self._setup_box_properties(
                self.width, self.center, self.orienter.unit_vectors
            )
            image = self.new_image()
            args, kwargs = self.get_sampler_args(image)
            sampler = self.get_sampler(args, kwargs)
            image = self._render(double_check, num_threads, image, sampler)
            # Key each tile by its flattened index for reassembly below.
            sto.id = self.imj * self.nimx + self.imi
            sto.result = image
        image = self.reduce_images(my_storage)
        self.save_image(image, fn=fn, clip_ratio=clip_ratio)
        return image

    def reduce_images(self, im_dict):
        """Stitch per-tile images (keyed by flattened tile index) into the
        full mosaic on rank 0; other ranks return 0."""
        final_image = 0
        if self.comm.rank == 0:
            offx, offy = np.meshgrid(range(self.nimx), range(self.nimy))
            offxy = zip(offx.ravel(), offy.ravel())
            nx, ny = self.resolution
            final_image = np.empty(
                (nx * self.nimx, ny * self.nimy, 4), dtype="float64", order="C"
            )
            for xy in offxy:
                i, j = xy
                ind = j * self.nimx + i
                final_image[i * nx : (i + 1) * nx, j * ny : (j + 1) * ny, :] = im_dict[
                    ind
                ]
        return final_image
1983
1984
# Make this camera available by name through yt's data-object registry.
data_object_registry["mosaic_camera"] = MosaicCamera
1986
1987
def plot_allsky_healpix(
    image,
    nside,
    fn,
    label="",
    rotation=None,
    take_log=True,
    resolution=512,
    cmin=None,
    cmax=None,
):
    """Render a HEALPix map as an Aitoff all-sky image and save it to *fn*.

    Returns the pixelized image array and the per-pixel sample count from
    ``pixelize_healpix``.
    """
    import matplotlib.backends.backend_agg
    import matplotlib.figure

    rot = np.eye(3).astype("float64") if rotation is None else rotation

    img, count = pixelize_healpix(nside, image, resolution, resolution, rot)

    fig = matplotlib.figure.Figure((10, 5))
    ax = fig.add_subplot(1, 1, 1, projection="aitoff")
    # Optionally display in log space.
    scale = np.log10 if take_log else (lambda x: x)
    implot = ax.imshow(
        scale(img),
        extent=(-np.pi, np.pi, -np.pi / 2, np.pi / 2),
        clip_on=False,
        aspect=0.5,
        vmin=cmin,
        vmax=cmax,
    )
    cbar = fig.colorbar(implot, orientation="horizontal")
    cbar.set_label(label)
    # All-sky plots carry no meaningful axis ticks.
    ax.xaxis.set_ticks(())
    ax.yaxis.set_ticks(())
    matplotlib.backends.backend_agg.FigureCanvasAgg(fig).print_figure(fn)
    return img, count
2032
2033
class ProjectionCamera(Camera):
    r"""An off-axis projection "camera".

    Instead of volume rendering with a transfer function, this camera
    integrates a single *field* along parallel rays (optionally weighted
    by *weight*), producing an off-axis projection image.
    """

    def __init__(
        self,
        center,
        normal_vector,
        width,
        resolution,
        field,
        weight=None,
        volume=None,
        no_ghost=False,
        north_vector=None,
        ds=None,
        interpolated=False,
        method="integrate",
    ):

        if not interpolated:
            # The plain ProjectionSampler streams grids directly (see
            # _render), so no kd-tree volume is needed; use a placeholder.
            volume = 1

        self.interpolated = interpolated
        self.field = field
        self.weight = weight
        self.resolution = resolution
        self.method = method

        fields = [field]
        if self.weight is not None:
            # This is a temporary field, which we will remove at the end;
            # it is given a unique name to avoid conflicting with other
            # class instances.
            self.weightfield = ("index", "temp_weightfield_%u" % (id(self),))

            def _make_wf(f, w):
                # Derived field holding field * weight, so both the weighted
                # field and the weight itself integrate in a single pass.
                def temp_weightfield(a, b):
                    tr = b[f].astype("float64") * b[w]
                    return b.apply_units(tr, a.units)

                return temp_weightfield

            ds.field_info.add_field(
                self.weightfield, function=_make_wf(self.field, self.weight)
            )
            # Now we have to tell the dataset to add it and to calculate
            # its dependencies..
            deps, _ = ds.field_info.check_derived_fields([self.weightfield])
            ds.field_dependencies.update(deps)
            fields = [self.weightfield, self.weight]

        self.fields = fields
        # Projections integrate linear values; never log-scale the fields.
        self.log_fields = [False] * len(self.fields)
        Camera.__init__(
            self,
            center,
            normal_vector,
            width,
            resolution,
            None,
            fields=fields,
            ds=ds,
            volume=volume,
            log_fields=self.log_fields,
            north_vector=north_vector,
            no_ghost=no_ghost,
        )

    # this would be better in an __exit__ function, but that would require
    # changes in code that uses this class
    def __del__(self):
        # Remove the temporary weight field registered in __init__ so it
        # does not leak into later uses of the dataset.
        if hasattr(self, "weightfield") and hasattr(self, "ds"):
            try:
                self.ds.field_info.pop(self.weightfield)
                self.ds.field_dependencies.pop(self.weightfield)
            except KeyError:
                pass
        try:
            Camera.__del__(self)
        except AttributeError:
            pass

    def get_sampler(self, args, kwargs):
        # Interpolated projections sample vertex-centered data via the
        # kd-tree; otherwise a plain projection sampler is used.
        if self.interpolated:
            sampler = InterpolatedProjectionSampler(*args, **kwargs)
        else:
            sampler = ProjectionSampler(*args, **kwargs)
        return sampler

    def initialize_source(self):
        # Only the interpolated path has a kd-tree volume needing setup.
        if self.interpolated:
            Camera.initialize_source(self)
        else:
            pass

    def get_sampler_args(self, image):
        # Inverse rotation matrix (column-major) packed together with the
        # back-center origin, as the plane-parallel sampler expects.
        rotp = np.concatenate(
            [self.orienter.inv_mat.ravel("F"), self.back_center.ravel()]
        )
        args = (
            np.atleast_3d(rotp),
            np.atleast_3d(self.box_vectors[2]),
            self.back_center,
            (
                -self.width[0] / 2.0,
                self.width[0] / 2.0,
                -self.width[1] / 2.0,
                self.width[1] / 2.0,
            ),
            image,
            self.orienter.unit_vectors[0],
            self.orienter.unit_vectors[1],
            np.array(self.width, dtype="float64"),
            "KDTree",
            self.sub_samples,
        )
        kwargs = {"lens_type": "plane-parallel"}
        return args, kwargs

    def finalize_image(self, image):
        """Convert the raw sampler buffer into a unitful ImageArray."""
        ds = self.ds
        dd = ds.all_data()
        field = dd._determine_fields([self.field])[0]
        finfo = ds._get_field_info(*field)
        dl = 1.0
        if self.method == "integrate":
            if self.weight is None:
                # Unweighted integration picks up a path-length factor.
                dl = self.width[2].in_units(ds.unit_system["length"])
            else:
                # Weighted projection: divide out the integrated weight.
                image[:, :, 0] /= image[:, :, 1]

        return ImageArray(image[:, :, 0], finfo.units, registry=ds.unit_registry) * dl

    def _render(self, double_check, num_threads, image, sampler):
        """Sample every grid intersecting the projection volume.

        The interpolated path renders through the kd-tree like a normal
        camera; otherwise grids are streamed from a bounding region.
        """
        if self.interpolated:
            return Camera._render(self, double_check, num_threads, image, sampler)
        ds = self.ds
        width = self.width[2]
        north_vector = self.orienter.unit_vectors[0]
        east_vector = self.orienter.unit_vectors[1]
        normal_vector = self.orienter.unit_vectors[2]
        fields = self.fields

        # Calculate the bounding box containing all eight corners of the
        # projection volume.
        mi = ds.domain_right_edge.copy()
        ma = ds.domain_left_edge.copy()
        for off1 in [-1, 1]:
            for off2 in [-1, 1]:
                for off3 in [-1, 1]:
                    this_point = (
                        self.center
                        + width / 2.0 * off1 * north_vector
                        + width / 2.0 * off2 * east_vector
                        + width / 2.0 * off3 * normal_vector
                    )
                    np.minimum(mi, this_point, mi)
                    np.maximum(ma, this_point, ma)
        # Now we have a bounding box.
        data_source = ds.region(self.center, mi, ma)

        for (grid, mask) in data_source.blocks:
            data = [(grid[field] * mask).astype("float64") for field in fields]
            pg = PartitionedGrid(
                grid.id,
                data,
                mask.astype("uint8"),
                grid.LeftEdge,
                grid.RightEdge,
                grid.ActiveDimensions.astype("int64"),
            )
            # Release grid memory as soon as the partitioned copy exists.
            grid.clear_data()
            sampler(pg, num_threads=num_threads)

        image = self.finalize_image(sampler.aimage)
        return image

    def save_image(self, image, fn=None, clip_ratio=None):
        """Write *image* to *fn* on the root rank, log-scaling when the
        field's ``take_log`` flag is set.

        Note: *clip_ratio* is accepted for interface compatibility but has
        no effect here; the previous implementation branched on it with
        two identical write_image calls.
        """
        dd = self.ds.all_data()
        field = dd._determine_fields([self.field])[0]
        finfo = self.ds._get_field_info(*field)
        if finfo.take_log:
            im = np.log10(image)
        else:
            im = image
        if self.comm.rank == 0 and fn is not None:
            write_image(im, fn)

    def snapshot(self, fn=None, clip_ratio=None, double_check=False, num_threads=0):

        if num_threads is None:
            num_threads = get_num_threads()

        image = self.new_image()

        args, kwargs = self.get_sampler_args(image)

        sampler = self.get_sampler(args, kwargs)

        self.initialize_source()

        image = ImageArray(
            self._render(double_check, num_threads, image, sampler),
            info=self.get_information(),
        )

        self.save_image(image, fn=fn, clip_ratio=clip_ratio)

        return image

    snapshot.__doc__ = Camera.snapshot.__doc__
2246
2247
# Make this camera available by name through yt's data-object registry.
data_object_registry["projection_camera"] = ProjectionCamera
2249
2250
class SphericalCamera(Camera):
    r"""A camera producing a full 360-degree equirectangular panorama.

    Rays are cast from the camera center in every direction on the
    sphere; a 2:1 aspect ratio is recommended for the resolution.
    """

    def __init__(self, *args, **kwargs):
        Camera.__init__(self, *args, **kwargs)
        if self.resolution[0] / self.resolution[1] != 2:
            mylog.info("Warning: It's recommended to set the aspect ratio to 2:1")
        # Pad by one pixel on each side; finalize_image trims the border.
        self.resolution = np.asarray(self.resolution) + 2

    def get_sampler_args(self, image):
        # Longitude (px) and latitude (py) samples covering the sphere.
        px = np.linspace(-np.pi, np.pi, self.resolution[0], endpoint=True)[:, None]
        py = np.linspace(-np.pi / 2.0, np.pi / 2.0, self.resolution[1], endpoint=True)[
            None, :
        ]

        # Unit ray directions on the sphere.
        vectors = np.zeros(
            (self.resolution[0], self.resolution[1], 3), dtype="float64", order="C"
        )
        vectors[:, :, 0] = np.cos(px) * np.cos(py)
        vectors[:, :, 1] = np.sin(px) * np.cos(py)
        vectors[:, :, 2] = np.sin(py)

        vectors = vectors * self.width[0]
        # All rays originate at the camera center.
        positions = self.center + vectors * 0
        R1 = get_rotation_matrix(0.5 * np.pi, [1, 0, 0])
        R2 = get_rotation_matrix(0.5 * np.pi, [0, 0, 1])
        uv = np.dot(R1, self.orienter.unit_vectors)
        uv = np.dot(R2, uv)
        # np.dot contracts the last axis of the (nx, ny, 3) array with the
        # rotation matrix directly; the reshape() calls the original made
        # here discarded their results and were no-ops, so they are gone.
        vectors = np.dot(vectors, uv)

        dummy = np.ones(3, dtype="float64")
        # Flatten to the (npixels, 1, n) layout the sampler expects.
        image.shape = (self.resolution[0] * self.resolution[1], 1, 4)
        vectors.shape = (self.resolution[0] * self.resolution[1], 1, 3)
        positions.shape = (self.resolution[0] * self.resolution[1], 1, 3)
        args = (
            positions,
            vectors,
            self.back_center,
            (0.0, 1.0, 0.0, 1.0),
            image,
            dummy,
            dummy,
            np.zeros(3, dtype="float64"),
            # The traversal-method entry was missing here, unlike every
            # other camera in this module (Fisheye, Projection,
            # StereoSpherical); without it the transfer function lands in
            # the wrong positional slot of the sampler.
            "KDTree",
            self.transfer_function,
            self.sub_samples,
        )
        return args, {"lens_type": "spherical"}

    def _render(self, double_check, num_threads, image, sampler):
        """Ray-cast through every brick and finalize the panorama."""
        ncells = sum(b.source_mask.size for b in self.volume.bricks)
        pbar = get_pbar("Ray casting", ncells)
        total_cells = 0
        if double_check:
            # Optional sanity pass: fail fast if any brick contains NaNs.
            for brick in self.volume.bricks:
                for data in brick.my_data:
                    if np.any(np.isnan(data)):
                        raise RuntimeError

        for brick in self.volume.traverse(self.front_center):
            sampler(brick, num_threads=num_threads)
            total_cells += brick.source_mask.size
            pbar.update(total_cells)

        pbar.finish()
        image = self.finalize_image(sampler.aimage)
        return image

    def finalize_image(self, image):
        """Reassemble, reduce across processors, and trim the padding."""
        view_pos = self.front_center
        image.shape = self.resolution[0], self.resolution[1], 4
        image = self.volume.reduce_tree_images(image, view_pos)
        if not self.transfer_function.grey_opacity:
            # Force full opacity unless grey opacity is in use.
            image[:, :, 3] = 1.0
        # Trim the one-pixel padding added in __init__.
        image = image[1:-1, 1:-1, :]
        return image
2326
2327
# Make this camera available by name through yt's data-object registry.
data_object_registry["spherical_camera"] = SphericalCamera
2329
2330
class StereoSphericalCamera(Camera):
    r"""A spherical (equirectangular) camera rendering a stereo pair.

    Two renders are performed with the ray origins displaced along the
    horizontal tangent direction by +/- ``disparity``, then stacked side
    by side into a single image.
    """

    def __init__(self, *args, **kwargs):
        self.disparity = kwargs.pop("disparity", 0.0)
        Camera.__init__(self, *args, **kwargs)
        self.disparity = self.ds.arr(self.disparity, units="code_length")
        # Signed per-eye offset; set in snapshot() before each render.
        self.disparity_s = self.ds.arr(0.0, units="code_length")
        if self.resolution[0] / self.resolution[1] != 2:
            mylog.info("Warning: It's recommended to set the aspect ratio to be 2:1")
        # Pad by one pixel on each side; _render trims the border.
        self.resolution = np.asarray(self.resolution) + 2
        if self.disparity <= 0.0:
            # Fall back to a small fraction of the image width.
            self.disparity = self.width[0] / 1000.0
            mylog.info(
                "Warning: Invalid value of disparity; now reset it to %f",
                self.disparity,
            )

    def get_sampler_args(self, image):
        # Longitude (px) and latitude (py) samples covering the sphere.
        px = np.linspace(-np.pi, np.pi, self.resolution[0], endpoint=True)[:, None]
        py = np.linspace(-np.pi / 2.0, np.pi / 2.0, self.resolution[1], endpoint=True)[
            None, :
        ]

        # Unit ray directions on the sphere.
        vectors = np.zeros(
            (self.resolution[0], self.resolution[1], 3), dtype="float64", order="C"
        )
        vectors[:, :, 0] = np.cos(px) * np.cos(py)
        vectors[:, :, 1] = np.sin(px) * np.cos(py)
        vectors[:, :, 2] = np.sin(py)
        # Horizontal tangent vectors, used to displace the eye position.
        vectors2 = np.zeros(
            (self.resolution[0], self.resolution[1], 3), dtype="float64", order="C"
        )
        vectors2[:, :, 0] = -np.sin(px) * np.ones((1, self.resolution[1]))
        vectors2[:, :, 1] = np.cos(px) * np.ones((1, self.resolution[1]))
        vectors2[:, :, 2] = 0

        # Ray origins shifted by the signed eye offset.
        positions = self.center + vectors2 * self.disparity_s
        vectors = vectors * self.width[0]
        R1 = get_rotation_matrix(0.5 * np.pi, [1, 0, 0])
        R2 = get_rotation_matrix(0.5 * np.pi, [0, 0, 1])
        uv = np.dot(R1, self.orienter.unit_vectors)
        uv = np.dot(R2, uv)
        # NOTE(review): reshape() returns a new array, so these two calls
        # discard their results (no-ops); np.dot handles the (nx, ny, 3)
        # array directly.
        vectors.reshape((self.resolution[0] * self.resolution[1], 3))
        vectors = np.dot(vectors, uv)
        vectors.reshape((self.resolution[0], self.resolution[1], 3))

        dummy = np.ones(3, dtype="float64")
        # Flatten to the (npixels, 1, n) layout the sampler expects.
        image.shape = (self.resolution[0] * self.resolution[1], 1, 4)
        vectors.shape = (self.resolution[0] * self.resolution[1], 1, 3)
        positions.shape = (self.resolution[0] * self.resolution[1], 1, 3)
        args = (
            positions,
            vectors,
            self.back_center,
            (0.0, 1.0, 0.0, 1.0),
            image,
            dummy,
            dummy,
            np.zeros(3, dtype="float64"),
            "KDTree",
            self.transfer_function,
            self.sub_samples,
        )
        kwargs = {"lens_type": "stereo-spherical"}
        return args, kwargs

    def snapshot(
        self,
        fn=None,
        clip_ratio=None,
        double_check=False,
        num_threads=0,
        transparent=False,
    ):
        """Render the left and right eye images and return them stacked
        side by side as a single ImageArray."""

        if num_threads is None:
            num_threads = get_num_threads()

        # Left eye: positive disparity.
        self.disparity_s = self.disparity
        image1 = self.new_image()
        args1, kwargs1 = self.get_sampler_args(image1)
        sampler1 = self.get_sampler(args1, kwargs1)
        self.initialize_source()
        image1 = self._render(double_check, num_threads, image1, sampler1, "(Left) ")

        # Right eye: negative disparity.
        self.disparity_s = -self.disparity
        image2 = self.new_image()
        args2, kwargs2 = self.get_sampler_args(image2)
        sampler2 = self.get_sampler(args2, kwargs2)
        self.initialize_source()
        image2 = self._render(double_check, num_threads, image2, sampler2, "(Right)")

        # Stack the two eyes side by side and reduce across processors.
        image = np.hstack([image1, image2])
        image = self.volume.reduce_tree_images(image, self.center)
        image = ImageArray(image, info=self.get_information())
        self.save_image(image, fn=fn, clip_ratio=clip_ratio, transparent=transparent)
        return image

    def _render(self, double_check, num_threads, image, sampler, msg):
        """Ray-cast one eye, labelling the progress bar with *msg*."""
        ncells = sum(b.source_mask.size for b in self.volume.bricks)
        pbar = get_pbar("Ray casting " + msg, ncells)
        total_cells = 0
        if double_check:
            # Optional sanity pass: fail fast if any brick contains NaNs.
            for brick in self.volume.bricks:
                for data in brick.my_data:
                    if np.any(np.isnan(data)):
                        raise RuntimeError

        for brick in self.volume.traverse(self.front_center):
            sampler(brick, num_threads=num_threads)
            total_cells += brick.source_mask.size
            pbar.update(total_cells)

        pbar.finish()

        image = sampler.aimage.copy()
        image.shape = self.resolution[0], self.resolution[1], 4
        if not self.transfer_function.grey_opacity:
            # Force full opacity unless grey opacity is in use.
            image[:, :, 3] = 1.0
        # Trim the one-pixel padding added in __init__.
        image = image[1:-1, 1:-1, :]
        return image
2451
2452
# Make this camera available by name through yt's data-object registry.
data_object_registry["stereospherical_camera"] = StereoSphericalCamera
2454
2455
def off_axis_projection(
    ds,
    center,
    normal_vector,
    width,
    resolution,
    field,
    weight=None,
    volume=None,
    no_ghost=False,
    interpolated=False,
    north_vector=None,
    method="integrate",
):
    r"""Project through a dataset, off-axis, and return the image plane.

    This function will accept the necessary items to integrate through a volume
    at an arbitrary angle and return the integrated field of view to the user.
    Note that if a weight is supplied, it will multiply the pre-interpolated
    values together, then create cell-centered values, then interpolate within
    the cell to conduct the integration.

    Parameters
    ----------
    ds : ~yt.data_objects.static_output.Dataset
        This is the dataset to volume render.
    center : array_like
        The current 'center' of the view port -- the focal point for the
        camera.
    normal_vector : array_like
        The vector between the camera position and the center.
    width : float or list of floats
        The current width of the image.  If a single float, the volume is
        cubical, but if not, it is left/right, top/bottom, front/back
    resolution : int or list of ints
        The number of pixels in each direction.
    field : string
        The field to project through the volume
    weight : optional, default None
        If supplied, the field will be pre-multiplied by this, then divided by
        the integrated value of this field.  This returns an average rather
        than a sum.
    volume : `yt.extensions.volume_rendering.AMRKDTree`, optional
        The volume to ray cast through.  Can be specified for finer-grained
        control, but otherwise will be automatically generated.
    no_ghost: bool, optional
        Optimization option.  If True, homogenized bricks will
        extrapolate out from grid instead of interpolating from
        ghost zones that have to first be calculated.  This can
        lead to large speed improvements, but at a loss of
        accuracy/smoothness in resulting image.  The effects are
        less notable when the transfer function is smooth and
        broad. Default: False
    interpolated : optional, default False
        If True, the data is first interpolated to vertex-centered data,
        then tri-linearly interpolated along the ray. Not suggested for
        quantitative studies.
    north_vector : array_like, optional
        A vector defining the 'up' direction of the image plane.  If not
        supplied, an arbitrary north vector is chosen automatically.
    method : string
         The method of projection.  Valid methods are:

         "integrate" with no weight_field specified : integrate the requested
         field along the line of sight.

         "integrate" with a weight_field specified : weight the requested
         field by the weighting field and integrate along the line of sight.

         "sum" : This method is the same as integrate, except that it does not
         multiply by a path length when performing the integration, and is
         just a straight summation of the field along the given axis. WARNING:
         This should only be used for uniform resolution grid datasets, as other
         datasets may result in unphysical images.

    Returns
    -------
    image : array
        An (N,N) array of the final integrated values, in float64 form.

    Examples
    --------

    >>> image = off_axis_projection(
    ...     ds, [0.5, 0.5, 0.5], [0.2, 0.3, 0.4], 0.2, N, "temperature", "density"
    ... )
    >>> write_image(np.log10(image), "offaxis.png")

    """
    # Delegate the actual ray casting to a ProjectionCamera configured with
    # the caller's view geometry; all keyword options are forwarded verbatim.
    projcam = ProjectionCamera(
        center,
        normal_vector,
        width,
        resolution,
        field,
        weight=weight,
        ds=ds,
        volume=volume,
        no_ghost=no_ghost,
        interpolated=interpolated,
        north_vector=north_vector,
        method=method,
    )
    image = projcam.snapshot()
    return image[:, :]