import numpy as np

from matplotlib import _api
from .axes_divider import make_axes_locatable, Size
from .mpl_axes import Axes


@_api.delete_parameter("3.3", "add_all")
def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True, **kwargs):
    """
    Create three small axes in a column to the right of *ax*.

    An axes divider is attached to *ax*: *ax* is re-located to the left
    column, spanning every row, and three new axes sharing its x and y axes
    are placed in a narrow right column.

    Parameters
    ----------
    ax : Axes
        The axes to divide.  Its axes locator is replaced.
    pad : float, default: 0.01
        Fraction of the axes height used as the gap between *ax* and the new
        column, and between the three new axes.
    axes_class : type, optional
        Class used to create the new axes.  Defaults to ``ax._axes_class``
        if that attribute exists, otherwise ``type(ax)``.
    add_all : bool, default: True
        Whether to add the new axes to the figure of *ax*.  Deprecated.
    **kwargs
        Forwarded to *axes_class* when creating each new axes.

    Returns
    -------
    list of Axes
        The three new axes, in the order of divider rows 4, 2, 0 (top to
        bottom).
    """

    divider = make_axes_locatable(ax)

    pad_size = pad * Size.AxesY(ax)

    # Each side panel gets one third of what remains after the two gaps.
    xsize = ((1-2*pad)/3) * Size.AxesX(ax)
    ysize = ((1-2*pad)/3) * Size.AxesY(ax)

    # Columns: [ax, gap, side panels]; rows: [panel, gap, panel, gap, panel].
    divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
    divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])

    # The main axes occupies column 0 across all rows (ny1=-1).
    ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))

    if axes_class is None:
        try:
            axes_class = ax._axes_class
        except AttributeError:
            axes_class = type(ax)

    fig = ax.get_figure()
    ax_rgb = []
    for ny in [4, 2, 0]:  # the three panel rows; rows 1 and 3 are gaps
        ax1 = axes_class(fig, ax.get_position(original=True),
                         sharex=ax, sharey=ax, **kwargs)
        ax1.set_axes_locator(divider.new_locator(nx=2, ny=ny))
        # Hide the panel's standard tick labels (x and y are shared with ax).
        for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
            t.set_visible(False)
        # Axes classes exposing a dict-like ``axis`` attribute (as RGBAxes
        # uses via ``ax.axis[:]``) draw their own tick labels; hide those
        # too.  Classes without one raise AttributeError, which is expected.
        try:
            for axis in ax1.axis.values():
                axis.major_ticklabels.set_visible(False)
        except AttributeError:
            pass

        ax_rgb.append(ax1)

    if add_all:
        for ax1 in ax_rgb:
            fig.add_axes(ax1)

    return ax_rgb


@_api.deprecated("3.3", alternative="ax.imshow(np.dstack([r, g, b]))")
def imshow_rgb(ax, r, g, b, **kwargs):
    """Show *r*, *g*, *b* stacked depth-wise as one image on *ax*."""
    return ax.imshow(np.dstack([r, g, b]), **kwargs)


class RGBAxes:
    """
    4-panel imshow (RGB, R, G, B).

    Layout:

    +---------------+-----+
    |               |  R  |
    +               +-----+
    |      RGB      |  G  |
    +               +-----+
    |               |  B  |
    +---------------+-----+

    Subclasses can override the ``_defaultAxesClass`` attribute.

    Attributes
    ----------
    RGB : ``_defaultAxesClass``
        The axes object for the three-channel imshow.
    R : ``_defaultAxesClass``
        The axes object for the red channel imshow.
    G : ``_defaultAxesClass``
        The axes object for the green channel imshow.
    B : ``_defaultAxesClass``
        The axes object for the blue channel imshow.
    """

    _defaultAxesClass = Axes

    @_api.delete_parameter("3.3", "add_all")
    def __init__(self, *args, pad=0, add_all=True, **kwargs):
        """
        Parameters
        ----------
        pad : float, default: 0
            Fraction of the axes height to put as padding.
        add_all : bool, default: True
            Whether to add the {rgb, r, g, b} axes to the figure.
            This parameter is deprecated.
        axes_class : matplotlib.axes.Axes
            Class used to create all four axes (popped from *kwargs*);
            defaults to ``_defaultAxesClass``.
        *args
            Unpacked into axes_class() init for RGB
        **kwargs
            Unpacked into axes_class() init for RGB, R, G, B axes
        """
        axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
        self.RGB = ax = axes_class(*args, **kwargs)
        if add_all:
            ax.get_figure().add_axes(ax)
        else:
            # Forward add_all only when it is False, so make_rgb_axes emits
            # its deprecation warning only in that case.
            kwargs["add_all"] = add_all
        self.R, self.G, self.B = make_rgb_axes(
            ax, pad=pad, axes_class=axes_class, **kwargs)
        # Set the line color and ticks for the axes.
        for ax1 in [self.RGB, self.R, self.G, self.B]:
            ax1.axis[:].line.set_color("w")
            ax1.axis[:].major_ticks.set_markeredgecolor("w")

    @_api.deprecated("3.3")
    def add_RGB_to_figure(self):
        """Add red, green and blue axes to the RGB composite's axes figure."""
        fig = self.RGB.get_figure()
        fig.add_axes(self.R)
        fig.add_axes(self.G)
        fig.add_axes(self.B)

    def imshow_rgb(self, r, g, b, **kwargs):
        """
        Create the four images {rgb, r, g, b}.

        Parameters
        ----------
        r, g, b : array-like
            The red, green, and blue arrays.  They must all have the same
            shape.
        kwargs : imshow kwargs
            kwargs get unpacked into the imshow calls for the four images.

        Returns
        -------
        rgb : matplotlib.image.AxesImage
        r : matplotlib.image.AxesImage
        g : matplotlib.image.AxesImage
        b : matplotlib.image.AxesImage

        Raises
        ------
        ValueError
            If *r*, *g* and *b* do not have the same shape.
        """
        # np.shape (unlike the .shape attribute) also accepts plain nested
        # sequences, matching the documented "array-like" inputs.
        r_shape, g_shape, b_shape = np.shape(r), np.shape(g), np.shape(b)
        if not (r_shape == g_shape == b_shape):
            raise ValueError(
                f'Input shapes ({r_shape}, {g_shape}, {b_shape}) do not match')
        RGB = np.dstack([r, g, b])
        # Per-channel images: same shape/dtype as RGB with only one channel
        # filled in, the other two left at zero.
        R = np.zeros_like(RGB)
        R[:, :, 0] = r
        G = np.zeros_like(RGB)
        G[:, :, 1] = g
        B = np.zeros_like(RGB)
        B[:, :, 2] = b
        im_rgb = self.RGB.imshow(RGB, **kwargs)
        im_r = self.R.imshow(R, **kwargs)
        im_g = self.G.imshow(G, **kwargs)
        im_b = self.B.imshow(B, **kwargs)
        return im_rgb, im_r, im_g, im_b


@_api.deprecated("3.3", alternative="RGBAxes")
class RGBAxesBase(RGBAxes):
    """Deprecated alias of `RGBAxes`."""