1import numpy as np
2
3from matplotlib import _api
4from .axes_divider import make_axes_locatable, Size
5from .mpl_axes import Axes
6
7
@_api.delete_parameter("3.3", "add_all")
def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True, **kwargs):
    """
    Divide *ax* to make room for three small axes in a column on its right.

    *ax* is relocated to the left column of an axes divider, spanning every
    row. Three new axes are stacked in the right-hand column. They share
    both x and y with *ax*, and their tick labels are hidden.

    Parameters
    ----------
    ax : Axes
        The axes to divide.
    pad : float
        Fraction of the axes height. It is used both as the gap between
        *ax* and the right-hand column and as the gap between the stacked
        axes. Each new axes is ``(1 - 2*pad)/3`` of the width and height
        of *ax*.
    axes_class : type, optional
        Class used for the three new axes. If None, ``ax._axes_class`` is
        used when that attribute exists, otherwise ``type(ax)``.
    add_all : bool, default: True
        Whether to add the three new axes to the figure of *ax*.
        This parameter is deprecated.
    **kwargs
        Passed to *axes_class* when constructing each new axes.

    Returns
    -------
    list
        The three new axes, from the top row of the divider to the bottom.
    """
    divider = make_axes_locatable(ax)

    # Horizontal cells: [main axes | gap | side column].
    # Vertical cells: three equal rows separated by two gaps. Together they
    # add up to exactly one axes height.
    gap = pad * Size.AxesY(ax)
    share = (1 - 2 * pad) / 3
    col_width = share * Size.AxesX(ax)
    row_height = share * Size.AxesY(ax)
    divider.set_horizontal([Size.AxesX(ax), gap, col_width])
    divider.set_vertical([row_height, gap, row_height, gap, row_height])

    # The original axes fills column 0 across all rows (ny1=-1).
    ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))

    if axes_class is None:
        # getattr with a default only swallows AttributeError.
        axes_class = getattr(ax, "_axes_class", type(ax))

    fig = ax.get_figure()
    ax_rgb = []
    # Cells 4, 2 and 0 are the three row cells; the odd indices are gaps.
    for row in (4, 2, 0):
        sub = axes_class(fig, ax.get_position(original=True),
                         sharex=ax, sharey=ax, **kwargs)
        sub.set_axes_locator(divider.new_locator(nx=2, ny=row))
        labels = [*sub.yaxis.get_ticklabels(), *sub.xaxis.get_ticklabels()]
        for label in labels:
            label.set_visible(False)
        try:
            # Only axes classes with a dict-like ``axis`` attribute have
            # per-side artists to hide; others presumably raise here.
            for side_artist in sub.axis.values():
                side_artist.major_ticklabels.set_visible(False)
        except AttributeError:
            pass
        ax_rgb.append(sub)

    if add_all:
        for sub in ax_rgb:
            fig.add_axes(sub)

    return ax_rgb
57
58
@_api.deprecated("3.3", alternative="ax.imshow(np.dstack([r, g, b]))")
def imshow_rgb(ax, r, g, b, **kwargs):
    """
    Show the channels *r*, *g*, *b* as one image on *ax*.

    The channels are stacked depth-wise (along the third axis) with
    `numpy.dstack`. The stacked array and *kwargs* are then passed to
    ``ax.imshow``, whose return value is returned unchanged.
    """
    rgb_image = np.dstack((r, g, b))
    return ax.imshow(rgb_image, **kwargs)
62
63
class RGBAxes:
    """
    4-panel imshow (RGB, R, G, B).

    Layout:

        +---------------+-----+
        |               |  R  |
        +               +-----+
        |      RGB      |  G  |
        +               +-----+
        |               |  B  |
        +---------------+-----+

    Subclasses can override the ``_defaultAxesClass`` attribute.

    Attributes
    ----------
    RGB : ``_defaultAxesClass``
        The axes object for the three-channel imshow.
    R : ``_defaultAxesClass``
        The axes object for the red channel imshow.
    G : ``_defaultAxesClass``
        The axes object for the green channel imshow.
    B : ``_defaultAxesClass``
        The axes object for the blue channel imshow.
    """

    _defaultAxesClass = Axes

    @_api.delete_parameter("3.3", "add_all")
    def __init__(self, *args, pad=0, add_all=True, **kwargs):
        """
        Parameters
        ----------
        pad : float, default: 0
            Fraction of the axes height to put as padding.
        add_all : bool, default: True
            Whether to add the {rgb, r, g, b} axes to the figure.
            This parameter is deprecated.
        axes_class : type, optional
            Class used to create all four axes. It is popped from *kwargs*
            and defaults to ``_defaultAxesClass``. Its instances must expose
            an indexable ``axis`` attribute, because the axis lines and ticks
            are recolored through ``axis[:]`` below.
        *args
            Unpacked into axes_class() init for RGB.
        **kwargs
            Unpacked into axes_class() init for RGB, R, G, B axes.
        """
        axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
        self.RGB = ax = axes_class(*args, **kwargs)
        if add_all:
            ax.get_figure().add_axes(ax)
        else:
            kwargs["add_all"] = add_all  # only show deprecation in that case
        self.R, self.G, self.B = make_rgb_axes(
            ax, pad=pad, axes_class=axes_class, **kwargs)
        # Set the line color and ticks for the axes.
        for ax1 in [self.RGB, self.R, self.G, self.B]:
            ax1.axis[:].line.set_color("w")
            ax1.axis[:].major_ticks.set_markeredgecolor("w")

    @_api.deprecated("3.3")
    def add_RGB_to_figure(self):
        """Add red, green and blue axes to the RGB composite's axes figure."""
        self.RGB.get_figure().add_axes(self.R)
        self.RGB.get_figure().add_axes(self.G)
        self.RGB.get_figure().add_axes(self.B)

    def imshow_rgb(self, r, g, b, **kwargs):
        """
        Create the four images {rgb, r, g, b}.

        Parameters
        ----------
        r, g, b : array-like
            The red, green, and blue arrays. All three must have the same
            shape.
        **kwargs
            Forwarded to each of the four imshow calls.

        Returns
        -------
        rgb : matplotlib.image.AxesImage
        r : matplotlib.image.AxesImage
        g : matplotlib.image.AxesImage
        b : matplotlib.image.AxesImage

        Raises
        ------
        ValueError
            If *r*, *g* and *b* do not have the same shape.
        """
        # Accept any array-like (e.g. nested lists), not just objects that
        # already have ``.shape``. asanyarray returns ndarrays (and
        # subclasses such as masked arrays) unchanged.
        r, g, b = map(np.asanyarray, (r, g, b))
        if not (r.shape == g.shape == b.shape):
            raise ValueError(
                f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
        RGB = np.dstack([r, g, b])
        # Each single-channel image has the composite's shape and dtype, with
        # only its own channel filled in; the other two channels stay zero.
        R = np.zeros_like(RGB)
        R[:, :, 0] = r
        G = np.zeros_like(RGB)
        G[:, :, 1] = g
        B = np.zeros_like(RGB)
        B[:, :, 2] = b
        im_rgb = self.RGB.imshow(RGB, **kwargs)
        im_r = self.R.imshow(R, **kwargs)
        im_g = self.G.imshow(G, **kwargs)
        im_b = self.B.imshow(B, **kwargs)
        return im_rgb, im_r, im_g, im_b
164
165
@_api.deprecated("3.3", alternative="RGBAxes")
class RGBAxesBase(RGBAxes):
    """Deprecated alias of `RGBAxes`; it adds no behavior of its own."""
169