158 lines
5.1 KiB
Python
158 lines
5.1 KiB
Python
|
from types import MethodType
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from .axes_divider import make_axes_locatable, Size
|
||
|
from .mpl_axes import Axes, SimpleAxisArtist
|
||
|
|
||
|
|
||
|
def make_rgb_axes(ax, pad=0.01, axes_class=None, **kwargs):
    """
    Create three channel Axes (R, G, B) in a column to the right of *ax*.

    *ax* is re-located into the left column of a divider built on it, and
    three new Axes sharing x and y with *ax* are placed in the right column.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        Axes instance to create the RGB Axes in.
    pad : float, optional
        Fraction of the Axes height to pad.
    axes_class : `matplotlib.axes.Axes` or None, optional
        Axes class to use for the R, G, and B Axes. If None, use
        the same class as *ax*.
    **kwargs
        Forwarded to *axes_class* init for the R, G, and B Axes.

    Returns
    -------
    list
        The three new Axes, in the order they are assigned to R, G and B.
    """
    divider = make_axes_locatable(ax)

    # Gap inserted between neighbouring channel Axes, scaled by ax's height.
    gap = pad * Size.AxesY(ax)

    # Each channel panel gets one third of what is left after the two gaps.
    share = (1 - 2 * pad) / 3
    column_width = share * Size.AxesX(ax)
    row_height = share * Size.AxesY(ax)

    divider.set_horizontal([Size.AxesX(ax), gap, column_width])
    divider.set_vertical([row_height, gap, row_height, gap, row_height])

    # The original Axes takes column 0, from row 0 through the last row.
    ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))

    if axes_class is None:
        axes_class = type(ax)

    fig = ax.get_figure()
    channel_axes = []
    # Rows 4, 2 and 0 are the three panel rows (odd rows are the gaps);
    # the first one created is returned first, i.e. as the R panel.
    for row in (4, 2, 0):
        new_ax = axes_class(fig, ax.get_position(original=True),
                            sharex=ax, sharey=ax, **kwargs)
        new_ax.set_axes_locator(divider.new_locator(nx=2, ny=row))

        # Hide the tick labels of the channel panels.
        labels = new_ax.yaxis.get_ticklabels() + new_ax.xaxis.get_ticklabels()
        for label in labels:
            label.set_visible(False)
        # Axes classes that expose ``axis`` as a dict of axis artists carry
        # their own tick labels too; for classes where ``axis`` is a plain
        # method, ``.values`` is missing and the AttributeError is ignored.
        try:
            for artist in new_ax.axis.values():
                artist.major_ticklabels.set_visible(False)
        except AttributeError:
            pass

        channel_axes.append(new_ax)

    # Register the new Axes with the figure only once all three exist.
    for new_ax in channel_axes:
        fig.add_axes(new_ax)

    return channel_axes
|
||
|
|
||
|
|
||
|
class RGBAxes:
    """
    4-panel `~.Axes.imshow` (RGB, R, G, B).

    Layout::

        ┌───────────────┬─────┐
        │               │  R  │
        │               ├─────┤
        │      RGB      │  G  │
        │               ├─────┤
        │               │  B  │
        └───────────────┴─────┘

    Subclasses can override the ``_defaultAxesClass`` attribute.
    By default RGBAxes uses `.mpl_axes.Axes`.

    Attributes
    ----------
    RGB : ``_defaultAxesClass``
        The Axes object for the three-channel `~.Axes.imshow`.
    R : ``_defaultAxesClass``
        The Axes object for the red channel `~.Axes.imshow`.
    G : ``_defaultAxesClass``
        The Axes object for the green channel `~.Axes.imshow`.
    B : ``_defaultAxesClass``
        The Axes object for the blue channel `~.Axes.imshow`.
    """

    _defaultAxesClass = Axes

    def __init__(self, *args, pad=0, **kwargs):
        """
        Parameters
        ----------
        pad : float, default: 0
            Fraction of the Axes height to put as padding.
        axes_class : `~matplotlib.axes.Axes`
            Axes class to use. If not provided, ``_defaultAxesClass`` is used.
        *args
            Forwarded to *axes_class* init for the RGB Axes
        **kwargs
            Forwarded to *axes_class* init for the RGB, R, G, and B Axes
        """
        # Pulled out of kwargs so it is not forwarded to the Axes constructors.
        axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
        self.RGB = ax = axes_class(*args, **kwargs)
        ax.get_figure().add_axes(ax)
        self.R, self.G, self.B = make_rgb_axes(
            ax, pad=pad, axes_class=axes_class, **kwargs)
        # Set the line color and ticks for the axes.
        for ax1 in [self.RGB, self.R, self.G, self.B]:
            if isinstance(ax1.axis, MethodType):
                # ``axis`` is a plain method here, so there is no dict of axis
                # artists; build a temporary one around ax1's spines.  It must
                # be bound to ax1 (the Axes it describes), not to this
                # RGBAxes container, which is not an Axes.
                ad = Axes.AxisDict(ax1)
                ad.update(
                    bottom=SimpleAxisArtist(ax1.xaxis, 1, ax1.spines["bottom"]),
                    top=SimpleAxisArtist(ax1.xaxis, 2, ax1.spines["top"]),
                    left=SimpleAxisArtist(ax1.yaxis, 1, ax1.spines["left"]),
                    right=SimpleAxisArtist(ax1.yaxis, 2, ax1.spines["right"]))
            else:
                ad = ax1.axis
            ad[:].line.set_color("w")
            ad[:].major_ticks.set_markeredgecolor("w")

    def imshow_rgb(self, r, g, b, **kwargs):
        """
        Create the four images {rgb, r, g, b}.

        Parameters
        ----------
        r, g, b : array-like
            The red, green, and blue arrays. All three must have the same
            shape.
        **kwargs
            Forwarded to `~.Axes.imshow` calls for the four images.

        Returns
        -------
        rgb : `~matplotlib.image.AxesImage`
        r : `~matplotlib.image.AxesImage`
        g : `~matplotlib.image.AxesImage`
        b : `~matplotlib.image.AxesImage`

        Raises
        ------
        ValueError
            If *r*, *g* and *b* do not all have the same shape.
        """
        # Accept any array-like (e.g. nested lists) as documented; plain
        # ``.shape`` access would fail on them.  asanyarray returns ndarrays
        # (and subclasses such as masked arrays) unchanged.
        r, g, b = (np.asanyarray(channel) for channel in (r, g, b))
        if not (r.shape == g.shape == b.shape):
            raise ValueError(
                f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
        # Stack the channels along a new last axis for the combined image.
        RGB = np.dstack([r, g, b])
        # Per-channel images: all channels zero except the one being shown.
        R = np.zeros_like(RGB)
        R[:, :, 0] = r
        G = np.zeros_like(RGB)
        G[:, :, 1] = g
        B = np.zeros_like(RGB)
        B[:, :, 2] = b
        im_rgb = self.RGB.imshow(RGB, **kwargs)
        im_r = self.R.imshow(R, **kwargs)
        im_g = self.G.imshow(G, **kwargs)
        im_b = self.B.imshow(B, **kwargs)
        return im_rgb, im_r, im_g, im_b
|