Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

aaronreidsmith / matplotlib   python

Repository URL to install this package:

Version: 3.1.2 

/ axes_grid1 / parasite_axes.py

import functools

from matplotlib import artist as martist, cbook, transforms as mtransforms
from matplotlib.axes import subplot_class_factory
from matplotlib.transforms import Bbox
from .mpl_axes import Axes

import numpy as np


class ParasiteAxesBase:
    """
    Mixin for axes that are overlaid on (and drawn by) a parent "host" axes.

    Instances share the parent's figure and position, are created without a
    frame, and hand their visible artists to the host via
    `get_images_artists`.
    """

    def get_images_artists(self):
        """
        Return the visible images and the remaining visible children.

        Returns
        -------
        images : list
            Visible members of ``self.images``, in their original order.
        artists : list
            Visible children (from ``get_children()``) that are not among
            *images*, in ``get_children()`` order.
        """
        images = [a for a in self.images if a.get_visible()]
        # Lists (rather than sets) keep the result order deterministic; set
        # iteration order depends on object hashes.
        image_set = set(images)
        artists = [a for a in self.get_children()
                   if a.get_visible() and a not in image_set]
        return images, artists

    def __init__(self, parent_axes, **kwargs):
        """
        Create a frameless axes in *parent_axes*' figure at its position.

        Any ``frameon`` given in *kwargs* is overridden with False; the other
        *kwargs* go to the base axes class.
        """
        self._parent_axes = parent_axes
        kwargs["frameon"] = False
        super().__init__(parent_axes.figure, parent_axes._position, **kwargs)

    @staticmethod
    def _axis_zorder(axisbelow):
        """
        Return the x/y-axis zorder for an ``axisbelow`` setting.

        ``"line"`` maps to 1.5, other truthy values to 0.5 and falsy values
        to 2.5, mirroring ``Axes.set_axisbelow``.
        """
        if axisbelow == "line":
            return 1.5
        return 0.5 if axisbelow else 2.5

    def cla(self):
        """
        Clear the axes, then hide every child and reuse the parent's
        ``_get_lines`` (presumably so new lines continue the host's property
        cycle -- confirm).
        """
        super().cla()

        martist.setp(self.get_children(), visible=False)
        self._get_lines = self._parent_axes._get_lines

        # Give the x- and y-axis the zorder matching the current
        # ``_axisbelow`` value.  The previous truthiness test sent "line"
        # (the rc default) to 0.5 instead of 1.5.
        zorder = self._axis_zorder(self._axisbelow)
        self.xaxis.set_zorder(zorder)
        self.yaxis.set_zorder(zorder)

    def pick(self, mouseevent):
        """
        Process a pick event for this axes and, when it comes from the host
        listing this axes as a parasite, for all of its children.
        """
        # This most likely goes to Artist.pick (depending on axes_class given
        # to the factory), which only handles pick events registered on the
        # axes associated with each child:
        super().pick(mouseevent)
        # But parasite axes are additionally given pick events from their host
        # axes (cf. HostAxesBase.pick), which we handle here.  The condition
        # does not depend on the child, so evaluate it once.
        host = mouseevent.inaxes
        if hasattr(host, "parasites") and self in host.parasites:
            for a in self.get_children():
                a.pick(mouseevent)


@functools.lru_cache(None)
def parasite_axes_class_factory(axes_class=None):
    """
    Return a parasite-axes class mixing `ParasiteAxesBase` into *axes_class*.

    Parameters
    ----------
    axes_class : type, optional
        Base axes class; None selects this module's ``Axes``.

    Returns
    -------
    type
        A class named ``"<axes_class.__name__>Parasite"``.  The same class
        object is returned for every call resolving to the same
        *axes_class*, however the argument is spelled.
    """
    if axes_class is None:
        axes_class = Axes
    # lru_cache keys on the literal call arguments, so f(), f(None), f(Axes)
    # and f(axes_class=Axes) are distinct entries.  Build the class in a
    # helper cached on the resolved class only, so they all share one class.
    return _make_parasite_axes_class(axes_class)


@functools.lru_cache(None)
def _make_parasite_axes_class(axes_class):
    """Create (once per *axes_class*) the ``<name>Parasite`` subclass."""
    return type("%sParasite" % axes_class.__name__,
                (ParasiteAxesBase, axes_class), {})


# Default parasite-axes class: ParasiteAxesBase mixed into this module's
# ``Axes`` (what parasite_axes_class_factory uses when axes_class is None).
ParasiteAxes = parasite_axes_class_factory()


class ParasiteAxesAuxTransBase:
    """
    Mixin for parasite axes whose data coordinates pass through an auxiliary
    transform (``transAux``) before the parent axes' ``transData``.
    """

    def __init__(self, parent_axes, aux_transform, viewlim_mode=None,
                 **kwargs):
        """
        Parameters
        ----------
        parent_axes : Axes
            The host axes.
        aux_transform : `~matplotlib.transforms.Transform`
            Stored as ``transAux``; prepended to the parent's ``transData``.
        viewlim_mode : {None, "equal", "transform"}
            How `update_viewlim` derives view limits from the parent.
        """
        self.transAux = aux_transform
        self.set_viewlim_mode(viewlim_mode)
        super().__init__(parent_axes, **kwargs)

    def _set_lim_and_transforms(self):
        # Axes coordinates are the parent's; data coordinates go through
        # transAux first, then through the parent's data transform.
        self.transAxes = self._parent_axes.transAxes

        self.transData = \
            self.transAux + \
            self._parent_axes.transData

        self._xaxis_transform = mtransforms.blended_transform_factory(
                self.transData, self.transAxes)
        self._yaxis_transform = mtransforms.blended_transform_factory(
                self.transAxes, self.transData)

    def set_viewlim_mode(self, mode):
        """Set the view-limit mode: None, ``"equal"`` or ``"transform"``."""
        cbook._check_in_list([None, "equal", "transform"], mode=mode)
        self._viewlim_mode = mode

    def get_viewlim_mode(self):
        """Return the view-limit mode set by `set_viewlim_mode`."""
        return self._viewlim_mode

    def update_viewlim(self):
        """
        Update the view limits from the parent's according to the mode.

        None leaves them untouched, ``"equal"`` copies the parent's, and
        ``"transform"`` maps the parent's through ``transAux.inverted()``.
        """
        viewlim = self._parent_axes.viewLim.frozen()
        mode = self.get_viewlim_mode()
        if mode is None:
            pass
        elif mode == "equal":
            self.axes.viewLim.set(viewlim)
        elif mode == "transform":
            self.axes.viewLim.set(
                viewlim.transformed(self.transAux.inverted()))
        else:
            cbook._check_in_list([None, "equal", "transform"], mode=mode)

    def _transform_grid(self, X, Y):
        """
        Map the grid coordinates *X*, *Y* through ``transAux``.

        When both are 1-D they are first expanded with `numpy.meshgrid`, the
        same grid pcolor/contour build from 1-D coordinates.  Returns two
        arrays with the (expanded) grid's shape.
        """
        X = np.asanyarray(X)
        Y = np.asanyarray(Y)
        if X.ndim == 1 and Y.ndim == 1:
            X, Y = np.meshgrid(X, Y)
        orig_shape = X.shape
        xyt = np.column_stack([X.flat, Y.flat])
        wxy = self.transAux.transform(xyt)
        return (wxy[:, 0].reshape(orig_shape),
                wxy[:, 1].reshape(orig_shape))

    def _pcolor(self, super_pcolor, *XYC, **kwargs):
        """
        Shared implementation of `pcolor` / `pcolormesh`.

        Accepts ``(C)`` or ``(X, Y, C)``.  Unless a ``transform`` kwarg is
        given, the grid is mapped through ``transAux`` and the result is drawn
        with the parent's ``transData``.
        """
        if len(XYC) == 1:
            C = XYC[0]
            # np.shape also works for nested lists, unlike C.shape.
            ny, nx = np.shape(C)
            # Cell edges at half-integers around integer cell centers.
            X, Y = np.meshgrid(np.arange(-0.5, nx), np.arange(-0.5, ny))
        else:
            X, Y, C = XYC

        if "transform" in kwargs:
            mesh = super_pcolor(X, Y, C, **kwargs)
        else:
            gx, gy = self._transform_grid(X, Y)
            mesh = super_pcolor(gx, gy, C, **kwargs)
            mesh.set_transform(self._parent_axes.transData)

        return mesh

    def pcolormesh(self, *XYC, **kwargs):
        """`~.Axes.pcolormesh` with the grid mapped through ``transAux``."""
        return self._pcolor(super().pcolormesh, *XYC, **kwargs)

    def pcolor(self, *XYC, **kwargs):
        """`~.Axes.pcolor` with the grid mapped through ``transAux``."""
        return self._pcolor(super().pcolor, *XYC, **kwargs)

    def _contour(self, super_contour, *XYCL, **kwargs):
        """
        Shared implementation of `contour` / `contourf`.

        Accepts ``(Z[, levels])`` or ``(X, Y, Z[, levels])``.  Unless a
        ``transform`` kwarg is given, the grid is mapped through ``transAux``
        and every resulting collection is drawn with the parent's
        ``transData``.
        """
        if len(XYCL) <= 2:
            C = XYCL[0]
            ny, nx = np.shape(C)
            X, Y = np.meshgrid(np.arange(0., nx), np.arange(0., ny))
            CL = XYCL
        else:
            X, Y = XYCL[:2]
            CL = XYCL[2:]

        if "transform" in kwargs:
            cont = super_contour(X, Y, *CL, **kwargs)
        else:
            gx, gy = self._transform_grid(X, Y)
            cont = super_contour(gx, gy, *CL, **kwargs)
            for c in cont.collections:
                c.set_transform(self._parent_axes.transData)

        return cont

    def contour(self, *XYCL, **kwargs):
        """`~.Axes.contour` with the grid mapped through ``transAux``."""
        return self._contour(super().contour, *XYCL, **kwargs)

    def contourf(self, *XYCL, **kwargs):
        """`~.Axes.contourf` with the grid mapped through ``transAux``."""
        return self._contour(super().contourf, *XYCL, **kwargs)

    def apply_aspect(self, position=None):
        """
        Refresh the view limits, then apply the base class's aspect logic.

        *position* is accepted (the host passes its position) but is not
        forwarded to the base class.
        """
        self.update_viewlim()
        super().apply_aspect()


@functools.lru_cache(None)
def parasite_axes_auxtrans_class_factory(axes_class=None):
    """
    Return a parasite-axes class that supports an auxiliary transform.

    Parameters
    ----------
    axes_class : type, optional
        None selects `ParasiteAxes`.  A `ParasiteAxesBase` subclass is used
        as is.  Any other class is first wrapped with
        `parasite_axes_class_factory`.

    Returns
    -------
    type
        A class named ``"<parasite class name>ParasiteAuxTrans"``.  The same
        class object is returned for every call that resolves to the same
        parasite class.
    """
    if axes_class is None:
        parasite_axes_class = ParasiteAxes
    elif not issubclass(axes_class, ParasiteAxesBase):
        parasite_axes_class = parasite_axes_class_factory(axes_class)
    else:
        parasite_axes_class = axes_class
    # lru_cache keys on the literal call arguments, so f() and
    # f(axes_class=ParasiteAxes) would otherwise build two distinct classes.
    # Cache the construction on the resolved parasite class instead.
    return _make_parasite_axes_auxtrans_class(parasite_axes_class)


@functools.lru_cache(None)
def _make_parasite_axes_auxtrans_class(parasite_axes_class):
    """Create (once per *parasite_axes_class*) the aux-trans subclass."""
    return type("%sParasiteAuxTrans" % parasite_axes_class.__name__,
                (ParasiteAxesAuxTransBase, parasite_axes_class),
                {'name': 'parasite_axes'})


# Default aux-transform parasite class, built on ParasiteAxes.  ParasiteAxes
# already derives from ParasiteAxesBase, so the factory uses it unchanged.
ParasiteAxesAuxTrans = parasite_axes_auxtrans_class_factory(
    axes_class=ParasiteAxes)


class HostAxesBase:
    """
    Mixin for axes that own a list of parasite axes (``self.parasites``).

    The host draws its parasites' visible artists as part of its own draw,
    forwards pick events to them, and includes them in legends and tight
    bounding boxes.  Subclasses are expected to provide ``_get_base_axes``
    (as `host_axes_class_factory` does).
    """

    def __init__(self, *args, **kwargs):
        self.parasites = []
        super().__init__(*args, **kwargs)

    def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=None):
        """
        Add and return a parasite axes whose data go through transform *tr*.

        Parameters
        ----------
        tr : `~matplotlib.transforms.Transform`
            Auxiliary transform; the parasite's ``transData`` becomes
            ``tr + self.transData``.
        viewlim_mode : {"equal", "transform", None}
            How the parasite's view limits follow this axes.
        axes_class : type, optional
            Passed to `parasite_axes_auxtrans_class_factory`.
        """
        parasite_axes_class = parasite_axes_auxtrans_class_factory(axes_class)
        ax2 = parasite_axes_class(self, tr, viewlim_mode)
        # note that ax2.transData == tr + ax1.transData
        # Anything you draw in ax2 will match the ticks and grids of ax1.
        self.parasites.append(ax2)
        ax2._remove_method = self.parasites.remove
        return ax2

    def _get_legend_handles(self, legend_handler_map=None):
        """Return this axes' legend handles followed by its parasites'."""
        # Pass the handler map to the host as well as to the parasites (it
        # was previously dropped for the host).  list() also copes with a
        # base implementation that returns an iterator.
        all_handles = list(super()._get_legend_handles(legend_handler_map))
        for ax in self.parasites:
            all_handles.extend(ax._get_legend_handles(legend_handler_map))
        return all_handles

    def draw(self, renderer):
        """
        Draw this axes together with its parasites' visible artists.

        Parasite images and artists are appended temporarily to
        ``self.images`` / ``self.artists``.  The original lists are restored
        even if drawing raises, so failed draws cannot leak parasite artists
        into the host.
        """
        orig_artists = list(self.artists)
        orig_images = list(self.images)

        locator = (self.get_axes_locator()
                   if hasattr(self, "get_axes_locator") else None)
        if locator:
            pos = locator(self, renderer)
            self.set_position(pos, which="active")
            self.apply_aspect(pos)
        else:
            self.apply_aspect()

        rect = self.get_position()

        try:
            for ax in self.parasites:
                ax.apply_aspect(rect)
                images, artists = ax.get_images_artists()
                self.images.extend(images)
                self.artists.extend(artists)

            super().draw(renderer)
        finally:
            self.artists = orig_artists
            self.images = orig_images

    def cla(self):
        """Clear all parasites, then this axes."""
        for ax in self.parasites:
            ax.cla()
        super().cla()

    def pick(self, mouseevent):
        """Handle a pick event here and forward it to every parasite."""
        super().pick(mouseevent)
        # Also pass pick events on to parasite axes and, in turn, their
        # children (cf. ParasiteAxesBase.pick)
        for a in self.parasites:
            a.pick(mouseevent)

    def twinx(self, axes_class=None):
        """
        Create a twin of Axes with a shared x-axis but independent y-axis.

        The y-axis of self will have ticks on the left and the returned axes
        will have ticks on the right.
        """
        if axes_class is None:
            axes_class = self._get_base_axes()

        parasite_axes_class = parasite_axes_class_factory(axes_class)

        ax2 = parasite_axes_class(self, sharex=self, frameon=False)
        self.parasites.append(ax2)
        ax2._remove_method = self._remove_twinx

        self.axis["right"].set_visible(False)

        ax2.axis["right"].set_visible(True)
        ax2.axis["left", "top", "bottom"].set_visible(False)

        return ax2

    def _remove_twinx(self, ax):
        """Detach a `twinx` parasite and restore this axes' right axis."""
        self.parasites.remove(ax)
        self.axis["right"].set_visible(True)
        self.axis["right"].toggle(ticklabels=False, label=False)

    def twiny(self, axes_class=None):
        """
        Create a twin of Axes with a shared y-axis but independent x-axis.

        The x-axis of self will have ticks on the bottom and the returned axes
        will have ticks on the top.
        """
        if axes_class is None:
            axes_class = self._get_base_axes()

        parasite_axes_class = parasite_axes_class_factory(axes_class)

        ax2 = parasite_axes_class(self, sharey=self, frameon=False)
        self.parasites.append(ax2)
        ax2._remove_method = self._remove_twiny

        self.axis["top"].set_visible(False)

        ax2.axis["top"].set_visible(True)
        ax2.axis["left", "right", "bottom"].set_visible(False)

        return ax2

    def _remove_twiny(self, ax):
        """Detach a `twiny` parasite and restore this axes' top axis."""
        self.parasites.remove(ax)
        self.axis["top"].set_visible(True)
        self.axis["top"].toggle(ticklabels=False, label=False)

    def twin(self, aux_trans=None, axes_class=None):
        """
        Create a twin of Axes with no shared axis.

        While self will have ticks on the left and bottom axis, the returned
        axes will have ticks on the top and right axis.  With *aux_trans*
        None the twin uses an identity transform and copies this axes' view
        limits; otherwise its limits are derived through *aux_trans*.
        """
        if axes_class is None:
            axes_class = self._get_base_axes()

        parasite_axes_auxtrans_class = \
            parasite_axes_auxtrans_class_factory(axes_class)

        if aux_trans is None:
            ax2 = parasite_axes_auxtrans_class(
                self, mtransforms.IdentityTransform(), viewlim_mode="equal")
        else:
            ax2 = parasite_axes_auxtrans_class(
                self, aux_trans, viewlim_mode="transform")
        self.parasites.append(ax2)
        ax2._remove_method = self._remove_twin

        self.axis["top", "right"].set_visible(False)

        ax2.axis["top", "right"].set_visible(True)
        ax2.axis["left", "bottom"].set_visible(False)

        return ax2

    def _remove_twin(self, ax):
        """Detach a `twin` parasite and restore this axes' top/right axes."""
        self.parasites.remove(ax)
        self.axis["top", "right"].set_visible(True)
        self.axis["top", "right"].toggle(ticklabels=False, label=False)

    def get_tightbbox(self, renderer, call_axes_locator=True,
                      bbox_extra_artists=None):
        """
        Return the union of this axes' and its parasites' tight bboxes.

        Missing (None) and zero-size boxes are ignored; None is returned
        when no usable box remains.
        """
        bbs = [ax.get_tightbbox(renderer, call_axes_locator=call_axes_locator)
               for ax in self.parasites]
        bbs.append(super().get_tightbbox(
            renderer,
            call_axes_locator=call_axes_locator,
            bbox_extra_artists=bbox_extra_artists))
        bbs = [b for b in bbs
               if b is not None and (b.width != 0 or b.height != 0)]
        # Bbox.union does not accept an empty list.
        return Bbox.union(bbs) if bbs else None


@functools.lru_cache(None)
def host_axes_class_factory(axes_class=None):
    """
    Return a host-axes class mixing `HostAxesBase` into *axes_class*.

    Parameters
    ----------
    axes_class : type, optional
        Base axes class; None selects this module's ``Axes``.

    Returns
    -------
    type
        A class named ``"<axes_class.__name__>HostAxes"`` whose
        ``_get_base_axes()`` returns *axes_class*.  The same class object is
        returned for every call resolving to the same *axes_class*, however
        the argument is spelled.
    """
    if axes_class is None:
        axes_class = Axes
    # lru_cache keys on the literal call arguments: host_axes() calls
    # f(None) while HostAxes is built with f(axes_class=Axes).  Caching the
    # construction on the resolved class makes both yield the same class.
    return _make_host_axes_class(axes_class)


@functools.lru_cache(None)
def _make_host_axes_class(axes_class):
    """Create (once per *axes_class*) the ``<name>HostAxes`` subclass."""
    def _get_base_axes(self):
        return axes_class

    return type("%sHostAxes" % axes_class.__name__,
                (HostAxesBase, axes_class),
                {'_get_base_axes': _get_base_axes})


def host_subplot_class_factory(axes_class=None):
    """
    Return a subplot class whose instances can host parasite axes.

    Parameters
    ----------
    axes_class : type, optional
        Passed to `host_axes_class_factory` (None selects its default base
        class, as for the other factories in this module).  The resulting
        host class is wrapped with `matplotlib.axes.subplot_class_factory`.
    """
    host_axes_class = host_axes_class_factory(axes_class=axes_class)
    subplot_host_class = subplot_class_factory(host_axes_class)
    return subplot_host_class


# Default host-axes class (HostAxesBase mixed into this module's ``Axes``)
# and its subplot variant built with matplotlib's subplot_class_factory.
HostAxes = host_axes_class_factory(axes_class=Axes)
SubplotHost = subplot_class_factory(HostAxes)


def host_axes(*args, axes_class=None, figure=None, **kwargs):
    """
    Create axes that can act as a host to parasitic axes.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        Figure to which the axes will be added. Defaults to the current figure
        `pyplot.gcf()`.

    axes_class : type, optional
        Axes class the host is built on; passed to `host_axes_class_factory`
        (None selects its default).

    *args, **kwargs
        Will be passed on to the underlying ``Axes`` object creation.

    Returns
    -------
    The new host axes, already added to *figure*.
    """
    import matplotlib.pyplot as plt
    host_axes_class = host_axes_class_factory(axes_class)
    if figure is None:
        figure = plt.gcf()
    ax = host_axes_class(figure, *args, **kwargs)
    figure.add_axes(ax)
    plt.draw_if_interactive()
    return ax


def host_subplot(*args, axes_class=None, figure=None, **kwargs):
    """
    Create a subplot that can act as a host to parasitic axes.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        Figure that receives the subplot; when omitted, the current figure
        `pyplot.gcf()` is used.

    axes_class : type, optional
        Axes class the host subplot is built on; passed to
        `host_subplot_class_factory`.

    *args, **kwargs
        Forwarded unchanged to the subplot class constructor.

    Returns
    -------
    The new host subplot, already added to *figure*.
    """
    import matplotlib.pyplot as plt

    subplot_cls = host_subplot_class_factory(axes_class)
    target_fig = plt.gcf() if figure is None else figure

    new_ax = subplot_cls(target_fig, *args, **kwargs)
    target_fig.add_subplot(new_ax)
    plt.draw_if_interactive()
    return new_ax