Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages

aaronreidsmith / matplotlib   python

Repository URL to install this package:

Version: 3.1.2 

/ axes_grid1 / axes_rgb.py

import numpy as np

from .axes_divider import make_axes_locatable, Size
from .mpl_axes import Axes


def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True):
    """
    Create three small axes beside *ax* that share its x and y axes.

    *ax* is handed to ``make_axes_locatable`` and receives a new axes
    locator; the three new axes are placed in column 2 of the resulting
    divider.

    Parameters
    ----------
    ax : axes
        The main axes to attach the three new axes to.
    pad : float
        Fraction of the axes height.
    axes_class : type, optional
        Class used to build the three new axes.  When None,
        ``ax._axes_class`` is used if that attribute exists, otherwise
        ``type(ax)``.
    add_all : bool
        If true, add the three new axes to the figure of *ax*.

    Returns
    -------
    list
        The three new axes, created for vertical divider cells 4, 2 and 0,
        in that order.
    """
    locatable = make_axes_locatable(ax)

    gap = Size.Fraction(pad, Size.AxesY(ax))

    # Each panel takes a third of what is left after the two gaps.
    panel_fraction = (1.-2.*pad)/3.
    panel_width = Size.Fraction(panel_fraction, Size.AxesX(ax))
    panel_height = Size.Fraction(panel_fraction, Size.AxesY(ax))

    locatable.set_horizontal([Size.AxesX(ax), gap, panel_width])
    locatable.set_vertical(
        [panel_height, gap, panel_height, gap, panel_height])

    # The main axes spans every vertical cell of column 0.
    ax.set_axes_locator(locatable.new_locator(0, 0, ny1=-1))

    if axes_class is None:
        axes_class = getattr(ax, "_axes_class", type(ax))

    def _build_panel(row):
        # One panel: same figure and original position as *ax*, sharing
        # both of its axes, then pinned to divider cell (2, row).
        panel = axes_class(ax.get_figure(),
                           ax.get_position(original=True),
                           sharex=ax, sharey=ax)
        panel.set_axes_locator(locatable.new_locator(nx=2, ny=row))

        labels = panel.yaxis.get_ticklabels() + panel.xaxis.get_ticklabels()
        for label in labels:
            label.set_visible(False)

        # Best effort: only axes whose ``axis`` attribute offers
        # ``values()`` are handled here; others raise AttributeError.
        try:
            for artist in panel.axis.values():
                artist.major_ticklabels.set_visible(False)
        except AttributeError:
            pass

        return panel

    panels = [_build_panel(row) for row in (4, 2, 0)]

    if not add_all:
        return panels

    figure = ax.get_figure()
    for panel in panels:
        figure.add_axes(panel)
    return panels


def imshow_rgb(ax, r, g, b, **kwargs):
    """
    Show three 2D channels as one composite RGB image on *ax*.

    Parameters
    ----------
    ax : axes-like
        Object whose ``imshow`` method receives the composite array.
    r, g, b : array-like
        The red, green and blue channels.  *r* must be two-dimensional and
        fixes the image shape; *g* and *b* are assigned into their channel
        planes with NumPy assignment, so they must be broadcastable to the
        shape of *r*.
    **kwargs
        Forwarded unchanged to ``ax.imshow``.

    Returns
    -------
    The value returned by ``ax.imshow``.

    Raises
    ------
    ValueError
        If *r* is not two-dimensional, or *g* / *b* cannot be broadcast to
        the shape of *r*.
    """
    # np.shape also works for nested lists, which have no ``.shape``.
    ny, nx = np.shape(r)

    # Fill one float64 (ny, nx, 3) array directly instead of summing three
    # mostly-zero temporaries.
    rgb = np.zeros([ny, nx, 3], dtype="d")
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b

    return ax.imshow(rgb, **kwargs)


class RGBAxesBase(object):
    """Base class for a 4-panel imshow (RGB, R, G, B).

    Layout:
    +---------------+-----+
    |               |  R  |
    +               +-----+
    |      RGB      |  G  |
    +               +-----+
    |               |  B  |
    +---------------+-----+

    Attributes
    ----------
    _defaultAxesClass : matplotlib.axes.Axes
        defaults to 'Axes' in RGBAxes child class.
        No default in abstract base class
    RGB : _defaultAxesClass
        The axes object for the three-channel imshow
    R : _defaultAxesClass
        The axes object for the red channel imshow
    G : _defaultAxesClass
        The axes object for the green channel imshow
    B : _defaultAxesClass
        The axes object for the blue channel imshow
    """
    def __init__(self, *args, pad=0, add_all=True, **kwargs):
        """
        Parameters
        ----------
        pad : float
            Fraction of the axes height to put as padding.
            Defaults to 0.
        add_all : bool
            True: Add the {rgb, r, g, b} axes to the figure.
            Defaults to True.
        axes_class : type, optional
            Class used to construct all four axes.  When omitted or None,
            ``self._defaultAxesClass`` is used.

        *args
            Unpacked into axes_class() init for RGB
        **kwargs
            Unpacked into axes_class() init for RGB, R, G, B axes

        Raises
        ------
        AttributeError
            If no ``axes_class`` is given and the class has no
            ``_defaultAxesClass`` attribute.
        """
        # Look up the class default only when it is actually needed, so an
        # explicit ``axes_class`` works on subclasses that define no default.
        axes_class = kwargs.pop("axes_class", None)
        if axes_class is None:
            try:
                axes_class = self._defaultAxesClass
            except AttributeError:
                raise AttributeError(
                    'A subclass of RGBAxesBase must have a _defaultAxesClass '
                    'attribute. If you are not sure which axes class to use, '
                    'consider using mpl_toolkits.axes_grid1.mpl_axes.Axes.'
                ) from None

        ax = axes_class(*args, **kwargs)

        divider = make_axes_locatable(ax)

        pad_size = Size.Fraction(pad, Size.AxesY(ax))

        # Each channel panel takes a third of what is left after two pads.
        xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
        ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))

        divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
        divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])

        # The composite axes spans every vertical cell of column 0.
        ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))

        ax_rgb = []
        for ny in [4, 2, 0]:
            ax1 = axes_class(ax.get_figure(),
                             ax.get_position(original=True),
                             sharex=ax, sharey=ax, **kwargs)
            locator = divider.new_locator(nx=2, ny=ny)
            ax1.set_axes_locator(locator)
            ax1.axis[:].toggle(ticklabels=False)
            ax_rgb.append(ax1)

        self.RGB = ax
        self.R, self.G, self.B = ax_rgb

        if add_all:
            fig = ax.get_figure()
            fig.add_axes(ax)
            self.add_RGB_to_figure()

        self._config_axes()

    def _config_axes(self, line_color='w', marker_edge_color='w'):
        """Set the line color and ticks for the axes

        Parameters
        ----------
        line_color : any matplotlib color
        marker_edge_color : any matplotlib color
        """
        for ax1 in [self.RGB, self.R, self.G, self.B]:
            ax1.axis[:].line.set_color(line_color)
            ax1.axis[:].major_ticks.set_markeredgecolor(marker_edge_color)

    def add_RGB_to_figure(self):
        """Add the red, green and blue axes to the RGB composite's axes figure
        """
        self.RGB.get_figure().add_axes(self.R)
        self.RGB.get_figure().add_axes(self.G)
        self.RGB.get_figure().add_axes(self.B)

    def imshow_rgb(self, r, g, b, **kwargs):
        """Create the four images {rgb, r, g, b}

        Parameters
        ----------
        r : array-like
            The red array
        g : array-like
            The green array
        b : array-like
            The blue array
        kwargs : imshow kwargs
            kwargs get unpacked into the imshow calls for the four images

        Returns
        -------
        rgb : matplotlib.image.AxesImage
        r : matplotlib.image.AxesImage
        g : matplotlib.image.AxesImage
        b : matplotlib.image.AxesImage

        Raises
        ------
        ValueError
            If the three inputs do not all have the same shape.
        """
        # np.shape handles nested lists as well as arrays, so the documented
        # "array-like" inputs are accepted.
        r_shape, g_shape, b_shape = np.shape(r), np.shape(g), np.shape(b)
        if not (r_shape == g_shape == b_shape):
            raise ValueError('Input shapes do not match.'
                             '\nr.shape = {}'
                             '\ng.shape = {}'
                             '\nb.shape = {}'
                             .format(r_shape, g_shape, b_shape))
        RGB = np.dstack([r, g, b])
        # Single-channel images: zeros everywhere except the own channel.
        R = np.zeros_like(RGB)
        R[:, :, 0] = r
        G = np.zeros_like(RGB)
        G[:, :, 1] = g
        B = np.zeros_like(RGB)
        B[:, :, 2] = b

        im_rgb = self.RGB.imshow(RGB, **kwargs)
        im_r = self.R.imshow(R, **kwargs)
        im_g = self.G.imshow(G, **kwargs)
        im_b = self.B.imshow(B, **kwargs)

        return im_rgb, im_r, im_g, im_b


class RGBAxes(RGBAxesBase):
    """RGBAxesBase subclass whose default axes class is ``Axes``."""
    # Used by RGBAxesBase.__init__ when no ``axes_class`` keyword is given.
    _defaultAxesClass = Axes