1import abc
2from functools import wraps
3
4import numpy as np
5
6from yt.config import ytcfg
7from yt.data_objects.image_array import ImageArray
8from yt.funcs import ensure_numpy_array, is_sequence, mylog
9from yt.geometry.grid_geometry_handler import GridIndex
10from yt.geometry.oct_geometry_handler import OctreeIndex
11from yt.utilities.amr_kdtree.api import AMRKDTree
12from yt.utilities.lib.bounding_volume_hierarchy import BVH
13from yt.utilities.lib.misc_utilities import zlines, zpoints
14from yt.utilities.lib.octree_raytracing import OctreeRayTracing
15from yt.utilities.lib.partitioned_grid import PartitionedGrid
16from yt.utilities.on_demand_imports import NotAModule
17from yt.utilities.parallel_tools.parallel_analysis_interface import (
18    ParallelAnalysisInterface,
19)
20from yt.visualization.image_writer import apply_colormap
21
22from .transfer_function_helper import TransferFunctionHelper
23from .transfer_functions import (
24    ColorTransferFunction,
25    ProjectionTransferFunction,
26    TransferFunction,
27)
28from .utils import (
29    data_source_or_all,
30    get_corners,
31    new_interpolated_projection_sampler,
32    new_mesh_sampler,
33    new_projection_sampler,
34    new_volume_render_sampler,
35)
36from .zbuffer_array import ZBuffer
37
# Optional Embree (pyembree) support: if the compiled embree_mesh extensions
# are unavailable, fall back to the pure-yt ray-tracing engine and replace the
# modules with NotAModule sentinels that raise helpful errors on attribute use.
try:
    from yt.utilities.lib.embree_mesh import mesh_traversal
# Catch ValueError in case size of objects in Cython change
except (ImportError, ValueError):
    mesh_traversal = NotAModule("pyembree")
    ytcfg["yt", "ray_tracing_engine"] = "yt"
try:
    from yt.utilities.lib.embree_mesh import mesh_construction
# Catch ValueError in case size of objects in Cython change
except (ImportError, ValueError):
    mesh_construction = NotAModule("pyembree")
    ytcfg["yt", "ray_tracing_engine"] = "yt"
50
51
def invalidate_volume(f):
    """Decorator for setters that invalidate the cached volume.

    After the wrapped setter runs, the cached volume is dropped and marked
    invalid so it is rebuilt on the next render.  If the source's transfer
    function is a ProjectionTransferFunction, the source is also switched to
    projection sampling with linear (non-log) fields and no ghost zones.
    """

    @wraps(f)
    def inner(*args, **kwargs):
        result = f(*args, **kwargs)
        source = args[0]
        tf = source._transfer_function
        if isinstance(tf, ProjectionTransferFunction):
            source.sampler_type = "projection"
            source._log_field = False
            source._use_ghost_zones = False
        # Drop the cached volume and flag it for a rebuild.
        del source.volume
        source._volume_valid = False
        return result

    return inner
66
67
def validate_volume(f):
    """Decorator that ensures the source's volume is up to date before *f* runs.

    When the cached volume has been invalidated, its fields (and optional
    weight field) are re-set — honoring the source's ghost-zone preference —
    and the volume is marked valid again before delegating to *f*.
    """

    @wraps(f)
    def inner(*args, **kwargs):
        source = args[0]
        field_list = [source.field]
        log_list = [source.log_field]
        if source.weight_field is not None:
            field_list.append(source.weight_field)
            log_list.append(source.log_field)
        if not source._volume_valid:
            no_ghost = not source.use_ghost_zones
            source.volume.set_fields(field_list, log_list, no_ghost=no_ghost)
        source._volume_valid = True
        return f(*args, **kwargs)

    return inner
85
86
class RenderSource(ParallelAnalysisInterface):

    """Base Class for Render Sources.

    Will be inherited for volumes, streamlines, etc.

    """

    # Name of the volume decomposition method ("KDTree", "Octree", ...);
    # overridden by concrete subclasses.
    volume_method = None

    def __init__(self):
        super().__init__()
        # Opaque sources block light from sources behind them; OpaqueSource
        # subclasses flip this to True.
        self.opaque = False
        # Z-buffer used when compositing this source with others; populated
        # during rendering.
        self.zbuffer = None

    @abc.abstractmethod
    def render(self, camera, zbuffer=None):
        # NOTE(review): @abc.abstractmethod is only enforced when the class's
        # metaclass derives from ABCMeta; whether ParallelAnalysisInterface
        # provides that is not visible here — confirm before relying on it.
        pass

    @abc.abstractmethod
    def _validate(self):
        # Subclasses check their own preconditions (data source set, etc.).
        pass
109
110
class OpaqueSource(RenderSource):
    """Base class for render sources that are opaque.

    Concrete subclasses include LineSource, BoxSource, and similar
    non-translucent annotations.
    """

    def __init__(self):
        super().__init__()
        self.opaque = True

    def set_zbuffer(self, zbuffer):
        """Assign the z-buffer used to composite this source."""
        self.zbuffer = zbuffer
124
125
def create_volume_source(data_source, field):
    """Create the VolumeSource subclass appropriate for a data source.

    Parameters
    ----------
    data_source : :class:`AMR3DData` or :class:`Dataset`
        The data to be rendered; a dataset is promoted to its ``all_data``.
    field : field name
        The field to render.

    Returns
    -------
    A :class:`KDTreeVolumeSource` for grid-based indexes or an
    :class:`OctreeVolumeSource` for octree-based indexes.

    Raises
    ------
    NotImplementedError
        If the dataset's index type has no volume-rendering support.
    """
    data_source = data_source_or_all(data_source)
    ds = data_source.ds
    index_class = ds.index.__class__
    if issubclass(index_class, GridIndex):
        return KDTreeVolumeSource(data_source, field)
    elif issubclass(index_class, OctreeIndex):
        return OctreeVolumeSource(data_source, field)
    else:
        raise NotImplementedError(
            f"Volume rendering is not implemented for index type {index_class!r}."
        )
136
137
class VolumeSource(RenderSource, abc.ABC):
    """A class for rendering data from a volumetric data source

    Examples of such sources include a sphere, cylinder, or the
    entire computational domain.

    A :class:`VolumeSource` provides the framework to decompose an arbitrary
    yt data source into bricks that can be traversed and volume rendered.

    Parameters
    ----------
    data_source: :class:`AMR3DData` or :class:`Dataset`, optional
        This is the source to be rendered, which can be any arbitrary yt
        data object or dataset.
    field : string
        The name of the field to be rendered.

    Examples
    --------

    The easiest way to make a VolumeSource is to use the volume_render
    function, so that the VolumeSource gets created automatically. This
    example shows how to do this and then access the resulting source:

    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> im, sc = yt.volume_render(ds)
    >>> volume_source = sc.get_source(0)

    You can also create VolumeSource instances by hand and add them to Scenes.
    This example manually creates a VolumeSource, adds it to a scene, sets the
    camera, and renders an image.

    >>> import yt
    >>> from yt.visualization.volume_rendering.api import (
    ...     Camera, Scene, create_volume_source)
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> sc = Scene()
    >>> source = create_volume_source(ds.all_data(), "density")
    >>> sc.add_source(source)
    >>> sc.add_camera()
    >>> im = sc.render()

    """

    _image = None
    data_source = None
    volume_method = None

    def __init__(self, data_source, field):
        r"""Initialize a new volumetric source for rendering."""
        super().__init__()
        self.data_source = data_source_or_all(data_source)
        field = self.data_source._determine_fields(field)[0]
        self.current_image = None
        self.check_nans = False
        self.num_threads = 0
        self.num_samples = 10
        self.sampler_type = "volume-render"

        self._volume_valid = False

        # these are caches for properties, defined below
        self._volume = None
        self._transfer_function = None
        self._field = field
        self._log_field = self.data_source.ds.field_info[field].take_log
        self._use_ghost_zones = False
        self._weight_field = None

        # Use the dataset itself (.ds); the long-deprecated ``.pf`` alias is
        # no longer available on data objects.
        self.tfh = TransferFunctionHelper(self.data_source.ds)
        self.tfh.set_field(self.field)

    @property
    def transfer_function(self):
        """The transfer function associated with this VolumeSource"""
        if self._transfer_function is not None:
            return self._transfer_function

        if self.tfh.tf is not None:
            self._transfer_function = self.tfh.tf
            return self._transfer_function

        # No transfer function has been set yet: build a default one from the
        # helper using the current field and log setting.
        mylog.info("Creating transfer function")
        self.tfh.set_field(self.field)
        self.tfh.set_log(self.log_field)
        self.tfh.build_transfer_function()
        self.tfh.setup_default()
        self._transfer_function = self.tfh.tf

        return self._transfer_function

    @transfer_function.setter
    def transfer_function(self, value):
        self.tfh.tf = None
        valid_types = (
            TransferFunction,
            ColorTransferFunction,
            ProjectionTransferFunction,
            type(None),
        )
        if not isinstance(value, valid_types):
            raise RuntimeError(
                "transfer_function not a valid type, "
                "received object of type %s" % type(value)
            )
        if isinstance(value, ProjectionTransferFunction):
            self.sampler_type = "projection"
            if self._volume is not None:
                # Force the volume's fields to be re-set on the next render.
                self._volume_valid = False
        self._transfer_function = value

    @property
    def volume(self):
        """The abstract volume associated with this VolumeSource

        This object does the heavy lifting to access data in an efficient manner
        using a KDTree
        """
        return self._get_volume()

    @volume.setter
    def volume(self, value):
        assert isinstance(value, AMRKDTree)
        del self._volume
        # Adopt the fields/log settings already baked into the supplied tree.
        self._field = value.fields
        self._log_field = value.log_fields
        self._volume = value
        assert self._volume_valid

    @volume.deleter
    def volume(self):
        del self._volume
        self._volume = None

    @property
    def field(self):
        """The field to be rendered"""
        return self._field

    @field.setter
    @invalidate_volume
    def field(self, value):
        field = self.data_source._determine_fields(value)
        if len(field) > 1:
            # Bug fix: the format argument must be applied inside the
            # exception constructor, not to the raise expression.
            raise RuntimeError(
                "VolumeSource.field can only be a single field but received "
                f"multiple fields: {field}"
            )
        field = field[0]
        if self._field != field:
            log_field = self.data_source.ds.field_info[field].take_log
            # New field: drop cached transfer-function bounds.
            self.tfh.bounds = None
        else:
            log_field = self._log_field
        self._log_field = log_field
        # Store the resolved field tuple (as __init__ does) so the comparison
        # above works on subsequent assignments.
        self._field = field
        self.transfer_function = None
        self.tfh.set_field(field)
        self.tfh.set_log(log_field)

    @property
    def log_field(self):
        """Whether or not the field rendering is computed in log space"""
        return self._log_field

    @log_field.setter
    @invalidate_volume
    def log_field(self, value):
        self.transfer_function = None
        self.tfh.set_log(value)
        self._log_field = value

    @property
    def use_ghost_zones(self):
        """Whether or not ghost zones are used to estimate vertex-centered data
        values at grid boundaries"""
        return self._use_ghost_zones

    @use_ghost_zones.setter
    @invalidate_volume
    def use_ghost_zones(self, value):
        self._use_ghost_zones = value

    @property
    def weight_field(self):
        """The weight field for the rendering

        Currently this is only used for off-axis projections.
        """
        return self._weight_field

    @weight_field.setter
    @invalidate_volume
    def weight_field(self, value):
        self._weight_field = value

    def set_transfer_function(self, transfer_function):
        """Set transfer function for this source"""
        self.transfer_function = transfer_function
        return self

    def _validate(self):
        """Make sure that all dependencies have been met"""
        if self.data_source is None:
            raise RuntimeError("Data source not initialized")

    def set_volume(self, volume):
        """Associates an AMRKDTree with the VolumeSource"""
        self.volume = volume
        return self

    def set_field(self, field):
        """Set the source's field to render

        Parameters
        ----------

        field: field name
            The field to render
        """
        self.field = field
        return self

    def set_log(self, log_field):
        """Set whether the rendering of the source's field is done in log space

        Generally volume renderings of data whose values span a large dynamic
        range should be done on log space and volume renderings of data with
        small dynamic range should be done in linear space.

        Parameters
        ----------

        log_field: boolean
            If True, the volume rendering will be done in log space, and if False
            will be done in linear space.
        """
        self.log_field = log_field
        return self

    def set_weight_field(self, weight_field):
        """Set the source's weight field

        .. note::

          This is currently only used for renderings using the
          ProjectionTransferFunction

        Parameters
        ----------

        weight_field: field name
            The weight field to use in the rendering
        """
        self.weight_field = weight_field
        return self

    def set_use_ghost_zones(self, use_ghost_zones):
        """Set whether or not interpolation at grid edges uses ghost zones

        Parameters
        ----------

        use_ghost_zones: boolean
            If True, the AMRKDTree estimates vertex centered data using ghost
            zones, which can eliminate seams in the resulting volume rendering.
            Defaults to False for performance reasons.

        """
        self.use_ghost_zones = use_ghost_zones
        return self

    def set_sampler(self, camera, interpolated=True):
        """Sets a volume render sampler

        The type of sampler is determined based on the ``sampler_type`` attribute
        of the VolumeSource. Currently the ``volume_render`` and ``projection``
        sampler types are supported.

        The 'interpolated' argument is only meaningful for projections. If True,
        the data is first interpolated to the cell vertices, and then
        tri-linearly interpolated to the ray sampling positions. If False, then
        the cell-centered data is simply accumulated along the
        ray. Interpolation is always performed for volume renderings.

        """
        if self.sampler_type == "volume-render":
            sampler = new_volume_render_sampler(camera, self)
        elif self.sampler_type == "projection" and interpolated:
            sampler = new_interpolated_projection_sampler(camera, self)
        elif self.sampler_type == "projection":
            sampler = new_projection_sampler(camera, self)
        else:
            # Bug fix: the exception was previously constructed but not raised.
            raise NotImplementedError(f"{self.sampler_type} not implemented yet")
        self.sampler = sampler
        assert self.sampler is not None

    @abc.abstractmethod
    def _get_volume(self):
        """The abstract volume associated with this VolumeSource

        This object does the heavy lifting to access data in an efficient manner
        using a KDTree
        """
        pass

    @abc.abstractmethod
    @validate_volume
    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera` instance
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer` instance  # noqa: E501
            A zbuffer array. This is used for opaque sources to determine the
            z position of the source relative to other sources. Only useful if
            you are manually calling render on multiple sources. Scene.render
            uses this internally.

        Returns
        -------
        A :class:`yt.data_objects.image_array.ImageArray` instance containing
        the rendered image.

        """
        pass

    def finalize_image(self, camera, image):
        """Parallel reduce the image.

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera` instance
            The camera used to produce the volume rendering image.
        image: :class:`yt.data_objects.image_array.ImageArray` instance
            A reference to an image to fill
        """
        image.shape = camera.resolution[0], camera.resolution[1], 4
        # If the call is from VR, the image is rotated by 180 to get correct
        # up direction
        if not self.transfer_function.grey_opacity:
            # Without grey opacity the alpha channel is forced fully opaque.
            image[:, :, 3] = 1
        return image

    def __repr__(self):
        disp = f"<Volume Source>:{str(self.data_source)} "
        disp += f"transfer_function:{str(self._transfer_function)}"
        return disp
492
493
class KDTreeVolumeSource(VolumeSource):
    volume_method = "KDTree"

    def _get_volume(self):
        """The abstract volume associated with this VolumeSource

        Lazily constructs an AMRKDTree over the data source on first access
        and caches it for subsequent use.
        """
        if self._volume is None:
            mylog.info("Creating volume")
            self._volume = AMRKDTree(self.data_source.ds, data_source=self.data_source)
        return self._volume

    @validate_volume
    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera`
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer`
            A zbuffer array. This is used for opaque sources to determine the
            z position of the source relative to other sources. Only useful if
            you are manually calling render on multiple sources. Scene.render
            uses this internally.

        Returns
        -------
        A :class:`yt.data_objects.image_array.ImageArray` containing
        the rendered image.

        """
        self.zbuffer = zbuffer
        self.set_sampler(camera)
        assert self.sampler is not None

        mylog.debug("Casting rays")
        total_cells = 0
        if self.check_nans:
            # Optional sanity pass over every brick before casting any rays.
            for brick in self.volume.bricks:
                if any(np.isnan(data).any() for data in brick.my_data):
                    raise RuntimeError

        viewpoint = camera.lens.viewpoint
        for brick in self.volume.traverse(viewpoint):
            mylog.debug("Using sampler %s", self.sampler)
            self.sampler(brick, num_threads=self.num_threads)
            total_cells += np.prod(brick.my_data[0].shape)
        mylog.debug("Done casting rays")
        self.current_image = self.finalize_image(camera, self.sampler.aimage)

        if zbuffer is None:
            # No caller-provided z-buffer: make one at infinite depth.
            depth = np.full(self.current_image.shape[:2], np.inf)
            self.zbuffer = ZBuffer(self.current_image, depth)

        return self.current_image

    def finalize_image(self, camera, image):
        """Composite the per-brick images across the KD-tree, then apply the
        base-class finalization."""
        if self._volume is not None:
            image = self.volume.reduce_tree_images(image, camera.lens.viewpoint)
        return super().finalize_image(camera, image)
562
563
class OctreeVolumeSource(VolumeSource):
    # Identifier for the octree-based decomposition method.
    volume_method = "Octree"

    def __init__(self, *args, **kwa):
        super().__init__(*args, **kwa)
        # Vertex-centered data on octrees needs ghost zones to interpolate
        # across leaf boundaries, so they are enabled unconditionally here.
        self.set_use_ghost_zones(True)

    def _get_volume(self):
        """The abstract volume associated with this VolumeSource

        This object does the heavy lifting to access data in an efficient manner
        using an octree.
        """

        if self._volume is None:
            # Built lazily on first access and cached thereafter.
            mylog.info("Creating volume")
            volume = OctreeRayTracing(self.data_source)
            self._volume = volume

        return self._volume

    @validate_volume
    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera` instance
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer` instance  # noqa: E501
            A zbuffer array. This is used for opaque sources to determine the
            z position of the source relative to other sources. Only useful if
            you are manually calling render on multiple sources. Scene.render
            uses this internally.

        Returns
        -------
        A :class:`yt.data_objects.image_array.ImageArray` instance containing
        the rendered image.

        """
        self.zbuffer = zbuffer
        self.set_sampler(camera)
        if self.sampler is None:
            raise RuntimeError(
                "No sampler set. This is likely a bug as it should never happen."
            )

        data = self.data_source

        # Per-cell widths and centers in "unitary" (normalized-domain) units.
        dx = data["dx"].to("unitary").value[:, None]
        xyz = np.stack([data[_].to("unitary").value for _ in "x y z".split()], axis=-1)
        # Left and right edges of every cell.
        LE = xyz - dx / 2
        RE = xyz + dx / 2

        mylog.debug("Gathering data")
        # Pack the per-cell data plus the six edge columns into the
        # (1, ncells, 14, 1) layout PartitionedGrid expects; 14 implies
        # len(self.volume.data) == 8 — presumably the eight vertex-centered
        # values per cell (TODO confirm against OctreeRayTracing.data).
        dt = np.stack(list(self.volume.data) + [*LE.T, *RE.T], axis=-1).reshape(
            1, len(dx), 14, 1
        )
        # All cells are active in the mask.
        mask = np.full(dt.shape[1:], 1, dtype=np.uint8)
        dims = np.array([1, 1, 1], dtype="int64")
        pg = PartitionedGrid(0, dt, mask, LE.flatten(), RE.flatten(), dims, n_fields=1)

        mylog.debug("Casting rays")
        # The sampler traverses the octree directly rather than brick-by-brick.
        self.sampler(pg, oct=self.volume.octree)
        mylog.debug("Done casting rays")

        self.current_image = self.finalize_image(camera, self.sampler.aimage)

        if zbuffer is None:
            # No caller-provided z-buffer: create one at infinite depth.
            self.zbuffer = ZBuffer(
                self.current_image, np.full(self.current_image.shape[:2], np.inf)
            )

        return self.current_image
639
640
class MeshSource(OpaqueSource):
    """A source for unstructured mesh data.

    This functionality requires the embree ray-tracing engine and the
    associated pyembree python bindings to be installed in order to
    function.

    A :class:`MeshSource` provides the framework to volume render
    unstructured mesh data.

    Parameters
    ----------
    data_source: :class:`AMR3DData` or :class:`Dataset`, optional
        This is the source to be rendered, which can be any arbitrary yt
        data object or dataset.
    field : string
        The name of the field to be rendered.

    Examples
    --------
    >>> source = MeshSource(ds, ("connect1", "convected"))
    """

    _image = None
    data_source = None

    def __init__(self, data_source, field):
        r"""Initialize a new unstructured mesh source for rendering."""
        super().__init__()
        self.data_source = data_source_or_all(data_source)
        field = self.data_source._determine_fields(field)[0]
        self.field = field
        self.volume = None
        self.current_image = None
        self.engine = ytcfg.get("yt", "ray_tracing_engine")

        # default color map
        self._cmap = ytcfg.get("yt", "default_colormap")
        self._color_bounds = None

        # default mesh annotation options
        self._annotate_mesh = False
        self._mesh_line_color = None
        self._mesh_line_alpha = 1.0

        # Error checking
        assert self.field is not None
        assert self.data_source is not None
        if self.field[0] == "all":
            raise NotImplementedError(
                "Mesh unions are not implemented for 3D rendering"
            )

        if self.engine == "embree":
            self.volume = mesh_traversal.YTEmbreeScene()
            self.build_volume_embree()
        elif self.engine == "yt":
            self.build_volume_bvh()
        else:
            raise NotImplementedError(
                "Invalid ray-tracing engine selected. Choices are 'embree' and 'yt'."
            )

    @property
    def cmap(self):
        """The name of the colormap used when rendering this MeshSource.

        Should be a string, like 'arbre' or 'dusk'.
        """
        return self._cmap

    @cmap.setter
    def cmap(self, cmap_name):
        self._cmap = cmap_name
        # Re-color an existing rendering without re-casting any rays.
        if hasattr(self, "data"):
            self.current_image = self.apply_colormap()

    @property
    def color_bounds(self):
        """The (vmin, vmax) bounds used with the colormap to display the
        rendered image, e.g. (0.0, 2.0).

        If None, the bounds are inferred from the max and min of the
        rendered data.
        """
        return self._color_bounds

    @color_bounds.setter
    def color_bounds(self, bounds):
        self._color_bounds = bounds
        # Re-color an existing rendering without re-casting any rays.
        if hasattr(self, "data"):
            self.current_image = self.apply_colormap()

    def _validate(self):
        """Make sure that all dependencies have been met"""
        if self.data_source is None:
            raise RuntimeError("Data source not initialized.")

        if self.volume is None:
            raise RuntimeError("Volume not initialized.")

    def _extract_mesh_data(self):
        """Pull vertices, zero-based connectivity, and 2D field data for
        this source's mesh from the dataset index.

        Returns (vertices, indices, field_data).
        """
        ftype, fname = self.field
        # Mesh field types are named "connectN"; N-1 is the mesh index.
        mesh_id = int(ftype[-1]) - 1
        index = self.data_source.ds.index
        offset = index.meshes[mesh_id]._index_offset
        field_data = self.data_source[self.field].d  # strip units

        vertices = index.meshes[mesh_id].connectivity_coords
        indices = index.meshes[mesh_id].connectivity_indices - offset

        # if this is an element field, promote to 2D here
        if len(field_data.shape) == 1:
            field_data = np.expand_dims(field_data, 1)

        return vertices, indices, field_data

    def build_volume_embree(self):
        """

        This constructs the mesh that will be ray-traced by pyembree.

        """
        vertices, indices, field_data = self._extract_mesh_data()

        # Here, we decide whether to render based on high-order or
        # low-order geometry. Right now, high-order geometry is only
        # implemented for 20-point hexes.
        if indices.shape[1] == 20 or indices.shape[1] == 10:
            self.mesh = mesh_construction.QuadraticElementMesh(
                self.volume, vertices, indices, field_data
            )
        else:
            # if this is another type of higher-order element, we demote
            # to 1st order here, for now.
            if indices.shape[1] == 27:
                # hexahedral
                mylog.warning("27-node hexes not yet supported, dropping to 1st order.")
                field_data = field_data[:, 0:8]
                indices = indices[:, 0:8]

            self.mesh = mesh_construction.LinearElementMesh(
                self.volume, vertices, indices, field_data
            )

    def build_volume_bvh(self):
        """

        This constructs the mesh that will be ray-traced.

        """
        vertices, indices, field_data = self._extract_mesh_data()

        # Here, we decide whether to render based on high-order or
        # low-order geometry.
        if indices.shape[1] == 27:
            # hexahedral
            mylog.warning("27-node hexes not yet supported, dropping to 1st order.")
            field_data = field_data[:, 0:8]
            indices = indices[:, 0:8]

        self.volume = BVH(vertices, indices, field_data)

    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera`
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer`
            A zbuffer array. This is used for opaque sources to determine the
            z position of the source relative to other sources. Only useful if
            you are manually calling render on multiple sources. Scene.render
            uses this internally.

        Returns
        -------
        A :class:`yt.data_objects.image_array.ImageArray` containing
        the rendered image.

        """

        shape = (camera.resolution[0], camera.resolution[1], 4)
        if zbuffer is None:
            # Fresh buffers: transparent image at infinite depth.
            empty = np.empty(shape, dtype="float64")
            z = np.empty(empty.shape[:2], dtype="float64")
            empty[:] = 0.0
            z[:] = np.inf
            zbuffer = ZBuffer(empty, z)
        elif zbuffer.rgba.shape != shape:
            zbuffer = ZBuffer(zbuffer.rgba.reshape(shape), zbuffer.z.reshape(shape[:2]))
        self.zbuffer = zbuffer

        self.sampler = new_mesh_sampler(camera, self, engine=self.engine)

        mylog.debug("Casting rays")
        self.sampler(self.volume)
        mylog.debug("Done casting rays")

        self.finalize_image(camera)
        self.current_image = self.apply_colormap()

        # Composite this source with the existing z-buffer contents.
        zbuffer += ZBuffer(self.current_image.astype("float64"), self.sampler.azbuffer)
        zbuffer.rgba = ImageArray(zbuffer.rgba)
        self.zbuffer = zbuffer
        self.current_image = self.zbuffer.rgba

        if self._annotate_mesh:
            self.current_image = self.annotate_mesh_lines(
                self._mesh_line_color, self._mesh_line_alpha
            )

        return self.current_image

    def finalize_image(self, camera):
        """Reshape the sampler's raw output into the (Nx, Ny) data plane."""
        sam = self.sampler

        # reshape data
        Nx = camera.resolution[0]
        Ny = camera.resolution[1]
        self.data = sam.aimage[:, :, 0].reshape(Nx, Ny)

    def annotate_mesh_lines(self, color=None, alpha=1.0):
        r"""

        Modifies this MeshSource by drawing the mesh lines.
        This modifies the current image by drawing the element
        boundaries and returns the modified image.

        Parameters
        ----------
        color: array_like of shape (4,), optional
            The RGBA value to use to draw the mesh lines.
            Default is black.
        alpha : float, optional
            The opacity of the mesh lines. Default is 1.0 (opaque).

        """

        # Bug fix: set the private flag that render() actually checks;
        # previously an unrelated ``annotate_mesh`` attribute was set, so the
        # annotation preference was never remembered across renders.
        self._annotate_mesh = True
        self._mesh_line_color = color
        self._mesh_line_alpha = alpha

        if color is None:
            color = np.array([0, 0, 0, alpha])

        # Pixels the sampler flagged as lying on an element boundary.
        locs = (self.sampler.amesh_lines == 1,)

        self.current_image[:, :, 0][locs] = color[0]
        self.current_image[:, :, 1][locs] = color[1]
        self.current_image[:, :, 2][locs] = color[2]
        self.current_image[:, :, 3][locs] = color[3]

        return self.current_image

    def apply_colormap(self):
        """

        Applies a colormap to the current image without re-rendering.

        Returns
        -------
        current_image : A new image with the specified color scale applied to
            the underlying data.


        """

        image = (
            apply_colormap(
                self.data, color_bounds=self._color_bounds, cmap_name=self._cmap
            )
            / 255.0
        )
        # Zero out the alpha channel wherever no ray hit the mesh.
        alpha = image[:, :, 3]
        alpha[self.sampler.aimage_used == -1] = 0.0
        image[:, :, 3] = alpha
        return image

    def __repr__(self):
        disp = f"<Mesh Source>:{str(self.data_source)} "
        return disp
940
941
class PointSource(OpaqueSource):
    r"""A rendering source of opaque points in the scene.

    This class provides a mechanism for adding points to a scene; these
    points will be opaque, and can also be colored.

    Parameters
    ----------
    positions: array_like of shape (N, 3)
        The positions of points to be added to the scene. If specified with no
        units, the positions will be assumed to be in code units.
    colors : array_like of shape (N, 4), optional
        Per-point RGBA colors as floats in 0..1. Defaults to opaque white
        for every point.
    color_stride : int, optional
        The stride with which to access the colors when putting them on the
        scene.
    radii : array_like of shape (N), optional
        The radii of the points in the final image, in pixels (int). A
        scalar is broadcast to all points; the default of 0 draws each
        point one pixel wide.

    Examples
    --------

    This example creates a volume rendering and adds 1000 random points to
    the image:

    >>> import yt
    >>> import numpy as np
    >>> from yt.visualization.volume_rendering.api import PointSource
    >>> from yt.units import kpc
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> im, sc = yt.volume_render(ds)

    >>> npoints = 1000
    >>> vertices = np.random.random([npoints, 3]) * 1000 * kpc
    >>> colors = np.random.random([npoints, 4])
    >>> colors[:, 3] = 1.0

    >>> points = PointSource(vertices, colors=colors)
    >>> sc.add_source(points)

    >>> im = sc.render()

    """

    _image = None
    data_source = None

    def __init__(self, positions, colors=None, color_stride=1, radii=None):
        assert positions.ndim == 2 and positions.shape[1] == 3
        if colors is not None:
            assert colors.ndim == 2 and colors.shape[1] == 4
            assert colors.shape[0] == positions.shape[0]

        npts = positions.shape[0]
        if is_sequence(radii):
            assert radii.ndim == 1
            assert radii.shape[0] == npts
        elif radii is not None:
            # Scalar radius: broadcast to every point.
            radii = radii * np.ones(npts, dtype="int64")
        else:
            # Default radius of 0 pixels, i.e. each point is one pixel wide.
            radii = np.zeros(npts, dtype="int64")

        self.positions = positions
        # Default color: opaque white for every point.
        self.colors = np.ones((npts, 4)) if colors is None else colors
        self.color_stride = color_stride
        self.radii = radii

    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera`
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer`
            An existing zbuffer to composite the points into. This is used
            for opaque sources to determine the z position of the source
            relative to other sources. If omitted, a fresh transparent
            buffer at the camera resolution is created.

        Returns
        -------
        The updated
        :class:`yt.visualization.volume_rendering.zbuffer_array.ZBuffer`.

        """
        if zbuffer is None:
            empty = camera.lens.new_image(camera)
            z = np.full(empty.shape[:2], np.inf, dtype="float64")
            empty[:] = 0.0
            zbuffer = ZBuffer(empty, z)
        else:
            empty = zbuffer.rgba
            z = zbuffer.z

        # Project the points onto the image plane and splat them,
        # depth-tested against the zbuffer.
        camera.lens.setup_box_properties(camera)
        px, py, dz = camera.lens.project_to_plane(camera, self.positions)
        zpoints(empty, z, px, py, dz, self.colors, self.radii, self.color_stride)

        self.zbuffer = zbuffer
        return zbuffer

    def __repr__(self):
        return "<Point Source>"
1054
1055
class LineSource(OpaqueSource):
    r"""A render source for a sequence of opaque line segments.

    This class provides a mechanism for adding lines to a scene; these
    points will be opaque, and can also be colored.

    .. note::

        If adding a LineSource to your rendering causes the image to appear
        blank or fades a VolumeSource, try lowering the values specified in
        the alpha channel of the ``colors`` array.

    Parameters
    ----------
    positions: array_like of shape (N, 2, 3)
        The positions of the starting and stopping points for each line.
        For example, positions[0][0] and positions[0][1] would give the
        (x, y, z) coordinates of the beginning and end points of the first
        line, respectively. If specified with no units, assumed to be in
        code units.
    colors : array_like of shape (N, 4), optional
        The colors of the points, including an alpha channel, in floating
        point running from 0..1.  The four channels correspond to r, g, b,
        and alpha values. Note that they correspond to the line segment
        succeeding each point; this means that strictly speaking they need
        only be (N-1) in length.
    color_stride : int, optional
        The stride with which to access the colors when putting them on the
        scene.

    Examples
    --------

    This example creates a volume rendering and then adds some random lines
    to the image:

    >>> import yt
    >>> import numpy as np
    >>> from yt.visualization.volume_rendering.api import LineSource
    >>> from yt.units import kpc
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> im, sc = yt.volume_render(ds)

    >>> nlines = 4
    >>> vertices = np.random.random([nlines, 2, 3]) * 600 * kpc
    >>> colors = np.random.random([nlines, 4])
    >>> colors[:, 3] = 1.0

    >>> lines = LineSource(vertices, colors)
    >>> sc.add_source(lines)

    >>> im = sc.render()

    """

    _image = None
    data_source = None

    def __init__(self, positions, colors=None, color_stride=1):
        super().__init__()

        assert positions.ndim == 3
        assert positions.shape[1] == 2
        assert positions.shape[2] == 3
        if colors is not None:
            assert colors.ndim == 2
            assert colors.shape[1] == 4

        nseg = positions.shape[0]
        # zlines expects a flat (2 * N, 3) array of segment endpoints.
        self.positions = positions.reshape((2 * nseg, 3))

        # Default color: opaque white, one RGBA row per segment.
        self.colors = np.ones((nseg, 4)) if colors is None else colors
        self.color_stride = color_stride

    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera`
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer`
            An existing zbuffer to composite the lines into. This is used
            for opaque sources to determine the z position of the source
            relative to other sources. If omitted, a fresh transparent
            buffer at the camera resolution is created.

        Returns
        -------
        The updated
        :class:`yt.visualization.volume_rendering.zbuffer_array.ZBuffer`.

        """
        if zbuffer is None:
            empty = camera.lens.new_image(camera)
            z = np.full(empty.shape[:2], np.inf, dtype="float64")
            empty[:] = 0.0
            zbuffer = ZBuffer(empty, z)
        else:
            empty = zbuffer.rgba
            z = zbuffer.z

        # Project the segment endpoints onto the image plane and rasterize
        # them, depth-tested against the zbuffer.
        camera.lens.setup_box_properties(camera)
        px, py, dz = camera.lens.project_to_plane(camera, self.positions)

        px = px.astype("int64")
        py = py.astype("int64")

        seg_colors = self.colors.astype("float64")
        if px.ndim == 1:
            zlines(empty, z, px, py, dz, seg_colors, self.color_stride)
        else:
            # Stereo lenses return one projection per eye along axis 0.
            for eye in (0, 1):
                zlines(
                    empty,
                    z,
                    px[eye, :],
                    py[eye, :],
                    dz[eye, :],
                    seg_colors,
                    self.color_stride,
                )

        self.zbuffer = zbuffer
        return zbuffer

    def __repr__(self):
        return "<Line Source>"
1202
1203
class BoxSource(LineSource):
    r"""A render source for a box drawn with line segments.
    This render source will draw a box, with transparent faces, in data
    space coordinates.  This is useful for annotations.

    Parameters
    ----------
    left_edge: array-like of shape (3,), float
        The left edge coordinates of the box.
    right_edge : array-like of shape (3,), float
        The right edge coordinates of the box.
    color : array-like of shape (4,), float, optional
        The colors (including alpha) to use for the lines.
        Default is white with an alpha of 1.0.

    Examples
    --------

    This example shows how to use BoxSource to add an outline of the
    domain boundaries to a volume rendering.

    >>> import yt
    >>> from yt.visualization.volume_rendering.api import BoxSource
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> im, sc = yt.volume_render(ds)

    >>> box_source = BoxSource(
    ...     ds.domain_left_edge, ds.domain_right_edge, [1.0, 1.0, 1.0, 1.0]
    ... )
    >>> sc.add_source(box_source)

    >>> im = sc.render()

    """

    def __init__(self, left_edge, right_edge, color=None):
        assert left_edge.shape == (3,)
        assert right_edge.shape == (3,)

        if color is None:
            color = np.array([1.0, 1.0, 1.0, 1.0])
        color = ensure_numpy_array(color)
        color.shape = (1, 4)

        corners = get_corners(left_edge.copy(), right_edge.copy())

        # The 12 edges of the box as pairs of corner indices: one face,
        # the opposite face, then the four connecting edges.
        edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
        edges += [(4, 5), (5, 6), (6, 7), (7, 4)]
        edges += [(0, 4), (1, 5), (2, 6), (3, 7)]
        order = [idx for edge in edges for idx in edge]

        # Expand the corner list into 24 endpoints (12 segments).
        vertices = np.empty([24, 3])
        for i in range(3):
            vertices[:, i] = corners[order, i, ...].ravel(order="F")
        vertices = vertices.reshape((12, 2, 3))

        super().__init__(vertices, color, color_stride=24)
1260
1261
class GridSource(LineSource):
    r"""A render source for drawing grids in a scene.

    This render source will draw blocks that are within a given data
    source, by default coloring them by their level of resolution.

    Parameters
    ----------
    data_source: :class:`~yt.data_objects.api.DataContainer`
        The data container that will be used to identify grids to draw.
    alpha : float
        The opacity of the grids to draw.
    cmap : color map name
        The color map to use to map resolution levels to color.
    min_level : int, optional
        Minimum level to draw
    max_level : int, optional
        Maximum level to draw

    Examples
    --------

    This example makes a volume rendering and adds outlines of all the
    AMR grids in the simulation:

    >>> import yt
    >>> from yt.visualization.volume_rendering.api import GridSource
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> im, sc = yt.volume_render(ds)

    >>> grid_source = GridSource(ds.all_data(), alpha=1.0)

    >>> sc.add_source(grid_source)

    >>> im = sc.render()

    This example does the same thing, except it only draws the grids
    that are inside a sphere of radius (0.1, "unitary") located at the
    domain center:

    >>> import yt
    >>> from yt.visualization.volume_rendering.api import GridSource
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> im, sc = yt.volume_render(ds)

    >>> dd = ds.sphere("c", (0.1, "unitary"))
    >>> grid_source = GridSource(dd, alpha=1.0)

    >>> sc.add_source(grid_source)

    >>> im = sc.render()

    """

    def __init__(
        self, data_source, alpha=0.3, cmap=None, min_level=None, max_level=None
    ):
        # data_source_or_all presumably promotes a Dataset to ds.all_data()
        # and passes containers through unchanged — confirm in .utils.
        self.data_source = data_source_or_all(data_source)
        # Collect the 8 corner positions and the AMR level of every block.
        corners = []
        levels = []
        for block, _mask in self.data_source.blocks:
            block_corners = np.array(
                [
                    [block.LeftEdge[0], block.LeftEdge[1], block.LeftEdge[2]],
                    [block.RightEdge[0], block.LeftEdge[1], block.LeftEdge[2]],
                    [block.RightEdge[0], block.RightEdge[1], block.LeftEdge[2]],
                    [block.LeftEdge[0], block.RightEdge[1], block.LeftEdge[2]],
                    [block.LeftEdge[0], block.LeftEdge[1], block.RightEdge[2]],
                    [block.RightEdge[0], block.LeftEdge[1], block.RightEdge[2]],
                    [block.RightEdge[0], block.RightEdge[1], block.RightEdge[2]],
                    [block.LeftEdge[0], block.RightEdge[1], block.RightEdge[2]],
                ],
                dtype="float64",
            )
            corners.append(block_corners)
            levels.append(block.Level)
        # After stacking: corners has shape (8, 3, n_blocks),
        # levels has shape (n_blocks,).
        corners = np.dstack(corners)
        levels = np.array(levels)
        if cmap is None:
            cmap = ytcfg.get("yt", "default_colormap")

        # Drop blocks outside the requested level range; the max filter is
        # applied first, then the min filter on the already-reduced arrays.
        if max_level is not None:
            subset = levels <= max_level
            levels = levels[subset]
            corners = corners[:, :, subset]
        if min_level is not None:
            subset = levels >= min_level
            levels = levels[subset]
            corners = corners[:, :, subset]

        # Map each block's level to an RGBA color (normalized to 0..1),
        # spanning level 0 up to the index's maximum level.
        colors = (
            apply_colormap(
                levels * 1.0,
                color_bounds=[0, self.data_source.ds.index.max_level],
                cmap_name=cmap,
            )[0, :, :]
            / 255.0
        )
        # Override the colormap's alpha channel with the requested opacity.
        colors[:, 3] = alpha

        # Corner-index pairs tracing the 12 edges of each block's box.
        order = [0, 1, 1, 2, 2, 3, 3, 0]
        order += [4, 5, 5, 6, 6, 7, 7, 4]
        order += [0, 4, 1, 5, 2, 6, 3, 7]

        # Expand every block into 12 line segments (24 endpoints each).
        vertices = np.empty([corners.shape[2] * 2 * 12, 3])
        for i in range(3):
            vertices[:, i] = corners[order, i, ...].ravel(order="F")
        vertices = vertices.reshape((corners.shape[2] * 12, 2, 3))

        super().__init__(vertices, colors, color_stride=24)
1374
1375
class CoordinateVectorSource(OpaqueSource):
    r"""Draw coordinate vectors on the scene.

    This will draw a set of coordinate vectors on the camera image.  They
    will appear in the lower right of the image.

    Parameters
    ----------
    colors: array-like of shape (3,4), optional
        The RGBA values to use to draw the x, y, and z vectors. The default is
        [[1, 0, 0, alpha], [0, 1, 0, alpha], [0, 0, 1, alpha]]  where ``alpha``
        is set by the parameter below. If ``colors`` is set then ``alpha`` is
        ignored.
    alpha : float, optional
        The opacity of the vectors.

    Examples
    --------

    >>> import yt
    >>> from yt.visualization.volume_rendering.api import \
    ...     CoordinateVectorSource
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> im, sc = yt.volume_render(ds)

    >>> coord_source = CoordinateVectorSource()

    >>> sc.add_source(coord_source)

    >>> im = sc.render()

    """

    def __init__(self, colors=None, alpha=1.0):
        super().__init__()
        # Default palette: x is red, y is green, z is blue, all sharing the
        # requested alpha. A caller-supplied ``colors`` overrides everything.
        if colors is None:
            colors = np.zeros((3, 4))
            colors[0, 0] = 1.0  # x is red
            colors[1, 1] = 1.0  # y is green
            colors[2, 2] = 1.0  # z is blue
            colors[:, 3] = alpha
        self.colors = colors

    def render(self, camera, zbuffer=None):
        """Renders an image using the provided camera

        Parameters
        ----------
        camera: :class:`yt.visualization.volume_rendering.camera.Camera`
            A volume rendering camera. Can be any type of camera.
        zbuffer: :class:`yt.visualization.volume_rendering.zbuffer_array.Zbuffer`
            A zbuffer array. This is used for opaque sources to determine the
            z position of the source relative to other sources. Only useful if
            you are manually calling render on multiple sources. Scene.render
            uses this internally.

        Returns
        -------
        A :class:`yt.data_objects.image_array.ImageArray` containing
        the rendered image.

        """
        camera.lens.setup_box_properties(camera)
        center = camera.focus
        # Six endpoints: one (origin, tip) pair per axis, all starting at
        # the camera focus.
        positions = np.zeros([6, 3])
        positions[:] = center

        # Create vectors in the x,y,z directions; each tip extends 1/16 of
        # the camera width along its axis.
        for i in range(3):
            positions[2 * i + 1, i] += camera.width.in_units("code_length").d[i] / 16.0

        # Project to the image plane
        px, py, dz = camera.lens.project_to_plane(camera, positions)

        if len(px.shape) == 1:
            # Per-axis pixel extent of each projected vector
            # (tip minus origin).
            dpx = px[1::2] - px[::2]
            dpy = py[1::2] - py[::2]

            # Set the center of the coordinates to be in the lower left of the image
            lpx = camera.resolution[0] / 8
            lpy = camera.resolution[1] - camera.resolution[1] / 8  # Upside-downsies

            # Offset the pixels according to the projections above
            px[::2] = lpx
            px[1::2] = lpx + dpx
            py[::2] = lpy
            py[1::2] = lpy + dpy
            # Depth is forced to zero so the vectors draw in front of
            # everything in the zbuffer.
            dz[:] = 0.0
        else:
            # For stereo-lens, two sets of pos for each eye are contained in px...pz
            dpx = px[:, 1::2] - px[:, ::2]
            dpy = py[:, 1::2] - py[:, ::2]

            lpx = camera.resolution[0] / 16
            lpy = camera.resolution[1] - camera.resolution[1] / 8  # Upside-downsies

            # Offset the pixels according to the projections above
            px[:, ::2] = lpx
            px[:, 1::2] = lpx + dpx
            # Shift the second eye's pixels by half the horizontal
            # resolution, into the right half of the frame.
            px[1, :] += camera.resolution[0] / 2
            py[:, ::2] = lpy
            py[:, 1::2] = lpy + dpy
            dz[:, :] = 0.0

        # Create a zbuffer if needed
        if zbuffer is None:
            empty = camera.lens.new_image(camera)
            z = np.empty(empty.shape[:2], dtype="float64")
            empty[:] = 0.0
            z[:] = np.inf
            zbuffer = ZBuffer(empty, z)
        else:
            empty = zbuffer.rgba
            z = zbuffer.z

        # Draw the vectors

        px = px.astype("int64")
        py = py.astype("int64")

        if len(px.shape) == 1:
            zlines(empty, z, px, py, dz, self.colors.astype("float64"))
        else:
            # For stereo-lens, two sets of pos for each eye are contained
            # in px...pz
            zlines(
                empty, z, px[0, :], py[0, :], dz[0, :], self.colors.astype("float64")
            )
            zlines(
                empty, z, px[1, :], py[1, :], dz[1, :], self.colors.astype("float64")
            )

        # Set the new zbuffer
        self.zbuffer = zbuffer
        return zbuffer

    def __repr__(self):
        disp = "<Coordinates Source>"
        return disp
1518