Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
matplotlib / axes_grid1 / axes_grid.py
Size: Mime:
import matplotlib.cbook as cbook

import matplotlib.pyplot as plt
import matplotlib.axes as maxes
#import matplotlib.colorbar as mcolorbar
import colorbar as mcolorbar
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.ticker as ticker

from matplotlib.gridspec import SubplotSpec, GridSpec

from axes_divider import Size, SubplotDivider, LocatableAxes, Divider

#import numpy as np

def _tick_only(ax, bottom_on, left_on):
    """
    Toggle tick labels and axis label of the bottom and left axis of *ax*.

    Note the inverted sense of the flags: the *negation* of *bottom_on*
    (resp. *left_on*) is what gets passed as ``ticklabels`` and ``label``
    to ``ax.axis["bottom"].toggle`` (resp. ``ax.axis["left"].toggle``).
    """
    # Negate both flags up front, before touching any axis.
    show_bottom, show_left = not bottom_on, not left_on
    for side, show in (("bottom", show_bottom), ("left", show_left)):
        ax.axis[side].toggle(ticklabels=show, label=show)

class Colorbar(mcolorbar.Colorbar):
    """Colorbar subclass used by the colorbar axes defined in this module."""

    def _config_axes_deprecated(self, X, Y):
        '''
        Make an axes patch and outline.

        The outline polygon is obtained from ``self._outline(X, Y)``; it is
        used for the data limits of ``self.ax``, for the outline line and
        for the background patch.  Ticks, tick labels and the offset string
        come from ``self._ticker()`` and are applied to the y axis for a
        vertical colorbar and to the x axis otherwise.
        '''
        rc = mpl.rcParams
        axes = self.ax
        axes.set_frame_on(False)
        axes.set_navigate(False)

        # Fit the view limits to the outline polygon.
        outline_xy = self._outline(X, Y)
        axes.update_datalim(outline_xy)
        axes.set_xlim(*axes.dataLim.intervalx)
        axes.set_ylim(*axes.dataLim.intervaly)

        # Outline drawn as an unclipped line.
        self.outline = mlines.Line2D(outline_xy[:, 0], outline_xy[:, 1],
                                     color=rc['axes.edgecolor'],
                                     linewidth=rc['axes.linewidth'])
        axes.add_artist(self.outline)
        self.outline.set_clip_box(None)
        self.outline.set_clip_path(None)

        # Background patch placed below everything else (zorder=-1).
        face = rc['axes.facecolor']
        self.patch = mpatches.Polygon(outline_xy,
                                      edgecolor=face,
                                      facecolor=face,
                                      linewidth=0.01,
                                      zorder=-1)
        axes.add_artist(self.patch)

        ticks, ticklabels, offset_string = self._ticker()
        if self.orientation == 'vertical':
            set_ticks, set_labels = axes.set_yticks, axes.set_yticklabels
            tick_axis = axes.yaxis
        else:
            set_ticks, set_labels = axes.set_xticks, axes.set_xticklabels
            tick_axis = axes.xaxis

        set_ticks(ticks)
        set_labels(ticklabels)
        tick_axis.get_major_formatter().set_offset_string(offset_string)



class CbarAxesBase(object):
    """
    Mixin providing colorbar behaviour for an axes class.

    The methods rely on attributes supplied by the concrete class:
    ``orientation``, ``_default_label_on``, ``axis`` and the usual axes
    methods (``hold``, ``set_navigate``).
    """

    def colorbar(self, mappable, **kwargs):
        """
        Create a :class:`Colorbar` for *mappable* in this axes and return it.

        A ``locator`` keyword is popped from *kwargs* (default:
        ``ticker.MaxNLocator(5)``), stored as ``self.locator`` and passed on
        as ``ticks``.  The colorbar is horizontal when ``self.orientation``
        is "top" or "bottom" and vertical otherwise.  Remaining *kwargs* are
        forwarded to :class:`Colorbar`.
        """
        tick_locator = kwargs.pop("locator", None)
        if tick_locator is None:
            tick_locator = ticker.MaxNLocator(5)
        self.locator = tick_locator
        kwargs["ticks"] = tick_locator

        self.hold(True)
        is_horizontal = self.orientation in ("top", "bottom")
        orientation = "horizontal" if is_horizontal else "vertical"

        cb = Colorbar(self, mappable, orientation=orientation, **kwargs)
        self._config_axes()

        def on_changed(m):
            # Keep the colorbar in sync with its mappable.
            cb.set_cmap(m.get_cmap())
            cb.set_clim(m.get_clim())
            cb.update_bruteforce(m)

        self.cbid = mappable.callbacksSM.connect('changed', on_changed)
        mappable.set_colorbar(cb, self)
        return cb

    def _config_axes(self):
        '''
        Make an axes patch and outline.

        Disables navigation, toggles every axis off and then toggles the
        axis named by ``self.orientation`` according to
        ``self._default_label_on``.
        '''
        self.set_navigate(False)
        self.axis[:].toggle(all=False)
        self.axis[self.orientation].toggle(all=self._default_label_on)

    def toggle_label(self, b):
        """
        Remember *b* as the default label state and toggle the tick labels
        and label of the axis named by ``self.orientation`` accordingly.
        """
        self._default_label_on = b
        self.axis[self.orientation].toggle(ticklabels=b, label=b)



class CbarAxes(CbarAxesBase, LocatableAxes):
    """Locatable axes that hosts a colorbar (see :class:`CbarAxesBase`)."""

    def __init__(self, *kl, **kwargs):
        """
        Pop the mandatory ``orientation`` keyword and initialise the axes
        with the remaining arguments.

        Raises ValueError when ``orientation`` is missing or None.
        """
        side = kwargs.pop("orientation", None)
        if side is None:
            raise ValueError("orientation must be specified")

        self.orientation = side
        self._default_label_on = True
        self.locator = None

        # NOTE(review): super() is anchored at LocatableAxes, so the lookup
        # starts *after* LocatableAxes in the MRO -- presumably intentional;
        # confirm before changing.
        super(LocatableAxes, self).__init__(*kl, **kwargs)

    def cla(self):
        """Clear the axes, then re-apply the colorbar axes configuration."""
        super(LocatableAxes, self).cla()
        self._config_axes()


class Grid(object):
    """
    A class that creates a grid of Axes. In matplotlib, the axes
    location (and size) is specified in the normalized figure
    coordinates. This may not be ideal for images that needs to be
    displayed with a given aspect ratio.  For example, displaying
    images of a same size with some fixed padding between them cannot
    be easily done in matplotlib. AxesGrid is used in such case.
    """

    _defaultLocatableAxesClass = LocatableAxes

    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids = None,
                 direction="row",
                 axes_pad = 0.02,
                 add_all=True,
                 share_all=False,
                 share_x=True,
                 share_y=True,
                 #aspect=True,
                 label_mode="L",
                 axes_class=None,
                 ):
        """
        Build an :class:`Grid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        *rect* may also be a :class:`~matplotlib.gridspec.SubplotSpec` or a
        3-element sequence that is unpacked into the subplot divider.

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          ngrids            None      number of axes to create; defaults to
                                      nrows*ncols
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          share_x           True      [ True | False ]
          share_y           True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`, or a
                                      ``(class, kwargs_dict)`` pair
          ================  ========  =========================================

        Raises :class:`ValueError` if *ngrids* is not in
        ``1..nrows*ncols``, if *direction* is not "row" or "column", or if
        *rect* is a sequence whose length is neither 3 nor 4.
        """
        self._nrows, self._ncols = nrows_ncols
        n_cells = self._nrows * self._ncols

        if ngrids is None:
            ngrids = n_cells
        elif ngrids > n_cells or ngrids <= 0:
            raise ValueError(
                "ngrids must be between 1 and nrows*ncols (%d), got %r"
                % (n_cells, ngrids))

        self.ngrids = ngrids

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise ValueError(
                "direction must be 'row' or 'column', got %r" % (direction,))

        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        else:
            # NOTE(review): the original looked up ``.Axes`` on the default
            # locatable class unconditionally; fall back to matplotlib's
            # Axes when that attribute is not defined -- TODO confirm which
            # base class was intended.
            base_axes = getattr(self._defaultLocatableAxesClass, "Axes",
                                maxes.Axes)
            if isinstance(axes_class, type) and \
                   issubclass(axes_class, base_axes):
                axes_class_args = {}
            else:
                # Otherwise expect a (class, kwargs) pair.
                axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for _ in range(self._ncols)]
        self.axes_row = [[] for _ in range(self._nrows)]

        # The size lists start empty and are filled in by _update_locators.
        h = []
        v = []
        if (cbook.is_string_like(rect) or cbook.is_numlike(rect)
                or isinstance(rect, SubplotSpec)):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=False)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=False)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v,
                                    aspect=False)
        else:
            raise ValueError(
                "rect must be a subplot code, a SubplotSpec or a sequence "
                "of length 3 or 4, got %r" % (rect,))

        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None] * self._ncols
        self._row_refax = [None] * self._nrows
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = sharey = self._refax
            else:
                sharex = self._column_refax[col] if share_x else None
                sharey = self._row_refax[row] if share_y else None

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            # An axes created without a sharing partner becomes the
            # reference for its group.
            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

        # Last axes of the first column, i.e. the lower-left corner.
        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)


    def _init_axes_pad(self, axes_pad):
        """
        Store *axes_pad* and create the two fixed-size pad objects (one for
        the horizontal, one for the vertical direction).
        """
        self._axes_pad = axes_pad
        self._horiz_pad_size, self._vert_pad_size = (Size.Fixed(axes_pad),
                                                     Size.Fixed(axes_pad))


    def _update_locators(self):
        """
        Rebuild the divider's horizontal/vertical size lists and assign an
        axes locator to every axes of the grid.

        Each column (row) contributes one ``Size.Scaled(1)`` entry, with the
        shared pad size object inserted between neighbours.
        """
        horiz, col_slots = [], []
        for _ in self._column_refax:
            if horiz:
                horiz.append(self._horiz_pad_size)
            col_slots.append(len(horiz))
            horiz.append(Size.Scaled(1))

        # Rows are walked in reverse, so slot 0 belongs to the last row.
        vert, row_slots = [], []
        for _ in self._row_refax[::-1]:
            if vert:
                vert.append(self._vert_pad_size)
            row_slots.append(len(vert))
            vert.append(Size.Scaled(1))

        for index in range(self.ngrids):
            col, row = self._get_col_row(index)
            cell_locator = self._divider.new_locator(
                nx=col_slots[col],
                ny=row_slots[self._nrows - 1 - row])
            self.axes_all[index].set_axes_locator(cell_locator)

        self._divider.set_horizontal(horiz)
        self._divider.set_vertical(vert)



    def _get_col_row(self, n):
        """
        Return ``(col, row)`` of the *n*-th axes.

        With direction "column" the index runs down the columns first
        (``divmod(n, nrows)``); otherwise it runs along the rows first
        (``divmod(n, ncols)``).
        """
        if self._direction == "column":
            return divmod(n, self._nrows)

        row, col = divmod(n, self._ncols)
        return col, row


    def __getitem__(self, i):
        """Return ``self.axes_all[i]``."""
        axes = self.axes_all
        return axes[i]


    def get_geometry(self):
        """
        Return the geometry of the grid as a ``(nrows, ncols)`` tuple.
        """
        geometry = (self._nrows, self._ncols)
        return geometry

    def set_axes_pad(self, axes_pad):
        """
        Set axes_pad: store the value and update the ``fixed_size`` of both
        the horizontal and the vertical pad size objects.
        """
        self._axes_pad = axes_pad
        for pad_size in (self._horiz_pad_size, self._vert_pad_size):
            pad_size.fixed_size = axes_pad


    def get_axes_pad(self):
        """Return the stored axes_pad value."""
        pad = self._axes_pad
        return pad

    def set_aspect(self, aspect):
        """Set aspect by forwarding *aspect* to the underlying divider."""
        divider = self._divider
        divider.set_aspect(aspect)

    def get_aspect(self):
        """Return the aspect reported by the underlying divider."""
        divider = self._divider
        return divider.get_aspect()

    def set_label_mode(self, mode):
        """
        Set label_mode, i.e. which axes get their bottom/left tick labels
        and axis labels toggled via :func:`_tick_only`.

        *mode* is one of:

          - "all": every axes is passed ``bottom_on=False, left_on=False``.
          - "L": axes in the first column get ``left_on=False`` and the last
            axes of each column gets ``bottom_on=False``; everything else
            is passed True.
          - "1": only the lower-left axes (``axes_llc``) is passed
            ``bottom_on=False, left_on=False``; all others get True.

        Any other value leaves the axes untouched.

        Note that :func:`_tick_only` passes the *negation* of these flags
        on to the axis ``toggle`` call.
        """
        if mode == "all":
            for ax in self.axes_all:
                _tick_only(ax, False, False)
        elif mode == "L":
            first_col = self.axes_column[0]
            # left-most axes
            for ax in first_col[:-1]:
                _tick_only(ax, bottom_on=True, left_on=False)
            # lower-left axes
            _tick_only(first_col[-1], bottom_on=False, left_on=False)

            for col in self.axes_column[1:]:
                if not col:
                    # A column can be empty when ngrids is smaller than
                    # nrows*ncols; col[-1] would raise IndexError.
                    continue

                # axes with no labels
                for ax in col[:-1]:
                    _tick_only(ax, bottom_on=True, left_on=True)

                # bottom
                _tick_only(col[-1], bottom_on=False, left_on=True)

        elif mode == "1":
            for ax in self.axes_all:
                _tick_only(ax, bottom_on=True, left_on=True)

            _tick_only(self.axes_llc, bottom_on=False, left_on=False)

    def get_divider(self):
        """Return the divider object that lays out this grid."""
        divider = self._divider
        return divider
        
    def set_axes_locator(self, locator):
        """Forward *locator* to ``set_locator`` of the underlying divider."""
        divider = self._divider
        divider.set_locator(locator)

    def get_axes_locator(self):
        """Return the locator reported by the underlying divider."""
        divider = self._divider
        return divider.get_locator()

    def get_vsize_hsize(self):
        """Return whatever ``get_vsize_hsize`` of the divider returns."""
        divider = self._divider
        return divider.get_vsize_hsize()
#         from axes_size import AddList

#         vsize = AddList(self._divider.get_vertical())
#         hsize = AddList(self._divider.get_horizontal())

#         return vsize, hsize




class ImageGrid(Grid):
    """
    A class that creates a grid of Axes. In matplotlib, the axes
    location (and size) is specified in the normalized figure
    coordinates. This may not be ideal for images that needs to be
    displayed with a given aspect ratio.  For example, displaying
    images of a same size with some fixed padding between them cannot
    be easily done in matplotlib. ImageGrid is used in such case.
    """

    _defaultCbarAxesClass = CbarAxes

    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids = None,
                 direction="row",
                 axes_pad = 0.02,
                 add_all=True,
                 share_all=False,
                 aspect=True,
                 label_mode="L",
                 cbar_mode=None,
                 cbar_location="right",
                 cbar_pad=None,
                 cbar_size="5%",
                 cbar_set_cax=True,
                 axes_class=None,
                 ):
        """
        Build an :class:`ImageGrid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        *rect* may also be a :class:`~matplotlib.gridspec.SubplotSpec` or a
        3-element sequence that is unpacked into the subplot divider.

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          ngrids            None      number of axes to create; defaults to
                                      nrows*ncols
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          aspect            True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          cbar_mode         None      [ "each" | "single" ]
          cbar_location     "right"   [ "right" | "top" ]
          cbar_pad          None      defaults to *axes_pad* when None
          cbar_size         "5%"
          cbar_set_cax      True      [ True | False ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`, or a
                                      ``(class, kwargs_dict)`` pair
          ================  ========  =========================================

        *cbar_set_cax* : if True, each axes in the grid has a cax
          attribute that is bound to the associated cbar_axes (the first
          colorbar axes for every axes when *cbar_mode* is "single").

        Raises :class:`ValueError` if *ngrids* is not in
        ``1..nrows*ncols``, if *direction* is not "row" or "column", or if
        *rect* is a sequence whose length is neither 3 nor 4.
        """
        self._nrows, self._ncols = nrows_ncols
        n_cells = self._nrows * self._ncols

        if ngrids is None:
            ngrids = n_cells
        elif ngrids > n_cells or ngrids <= 0:
            raise ValueError(
                "ngrids must be between 1 and nrows*ncols (%d), got %r"
                % (n_cells, ngrids))

        self.ngrids = ngrids

        self._axes_pad = axes_pad

        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        # The colorbar pad falls back to the axes pad.
        if cbar_pad is None:
            self._colorbar_pad = axes_pad
        else:
            self._colorbar_pad = cbar_pad

        self._colorbar_size = cbar_size

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise ValueError(
                "direction must be 'row' or 'column', got %r" % (direction,))

        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        elif isinstance(axes_class, type) and \
                 issubclass(axes_class, maxes.Axes):
            # A bare Axes subclass: the original tested
            # ``isinstance(axes_class, maxes.Axes)``, which is never true
            # for a class object, so a plain class ended up being unpacked.
            axes_class_args = {}
        else:
            # Otherwise expect a (class, kwargs) pair.
            axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for _ in range(self._ncols)]
        self.axes_row = [[] for _ in range(self._nrows)]

        self.cbar_axes = []

        # The size lists start empty and are filled in by _update_locators.
        h = []
        v = []
        if (cbook.is_string_like(rect) or cbook.is_numlike(rect)
                or isinstance(rect, SubplotSpec)):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=aspect)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=aspect)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v,
                                    aspect=aspect)
        else:
            raise ValueError(
                "rect must be a subplot code, a SubplotSpec or a sequence "
                "of length 3 or 4, got %r" % (rect,))

        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None] * self._ncols
        self._row_refax = [None] * self._nrows
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = sharey = self._refax
            else:
                sharex = self._column_refax[col]
                sharey = self._row_refax[row]

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            # An axes created without a sharing partner becomes the
            # reference for its group.
            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

            # One colorbar axes is created per image axes, whatever the
            # colorbar mode; _update_locators decides which are visible.
            cax = self._defaultCbarAxesClass(fig, rect,
                                             orientation=self._colorbar_location)
            self.cbar_axes.append(cax)

        # Last axes of the first column, i.e. the lower-left corner.
        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all+self.cbar_axes:
                fig.add_axes(ax)

        if cbar_set_cax:
            if self._colorbar_mode == "single":
                for ax in self.axes_all:
                    ax.cax = self.cbar_axes[0]
            else:
                for ax, cax in zip(self.axes_all, self.cbar_axes):
                    ax.cax = cax

        self.set_label_mode(label_mode)



    def _update_locators(self):
        """
        Rebuild the divider's horizontal/vertical size lists and assign
        locators to the image axes and the colorbar axes.

        Column widths (row heights) follow the reference axes of each
        column (row), falling back to ``axes_llc`` when a column (row) has
        no reference axes.  Depending on the colorbar mode:

          - "each": a pad and a colorbar slot are added after every column
            (location "right") or row (location "top"); all colorbar axes
            are made visible.
          - "single": one pad and one colorbar slot are appended at the
            end; only the first colorbar axes is visible.
          - anything else: all colorbar axes are hidden and moved to a tiny
            rectangle at the figure corner.

        For a colorbar location other than "right" or "top" no extra slot
        is created and the colorbar axes reuses the most recently computed
        image-axes locator.
        """
        cbar_mode = self._colorbar_mode
        cbar_loc = self._colorbar_location
        new_locator = self._divider.new_locator

        h = []

        h_ax_pos = []
        h_cb_pos = []
        for ax in self._column_refax:
            if h:
                h.append(self._horiz_pad_size)

            h_ax_pos.append(len(h))

            sz = Size.AxesX(ax if ax else self.axes_llc)
            h.append(sz)

            if cbar_mode == "each" and cbar_loc == "right":
                h.append(Size.from_any(self._colorbar_pad, sz))
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))

        v = []

        v_ax_pos = []
        v_cb_pos = []
        # Rows are walked in reverse, so slot 0 belongs to the last row.
        for ax in self._row_refax[::-1]:
            # Vertical gaps use the vertical pad object (the original used
            # the horizontal one here, unlike Grid._update_locators).
            if v:
                v.append(self._vert_pad_size)

            v_ax_pos.append(len(v))

            sz = Size.AxesY(ax if ax else self.axes_llc)
            v.append(sz)

            if cbar_mode == "each" and cbar_loc == "top":
                v.append(Size.from_any(self._colorbar_pad, sz))
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            v_index = self._nrows - 1 - row

            locator = new_locator(nx=h_ax_pos[col], ny=v_ax_pos[v_index])
            self.axes_all[i].set_axes_locator(locator)

            if cbar_mode == "each":
                if cbar_loc == "right":
                    locator = new_locator(nx=h_cb_pos[col],
                                          ny=v_ax_pos[v_index])
                elif cbar_loc == "top":
                    locator = new_locator(nx=h_ax_pos[col],
                                          ny=v_cb_pos[v_index])
                # Other locations: the image-axes locator is reused.
                self.cbar_axes[i].set_axes_locator(locator)

        if cbar_mode == "single":
            if cbar_loc == "right":
                sz = Size.Fraction(self._nrows, Size.AxesX(self.axes_llc))
                h.append(Size.from_any(self._colorbar_pad, sz))
                h.append(Size.from_any(self._colorbar_size, sz))
                # Last horizontal slot, spanning all vertical slots.
                locator = new_locator(nx=-2, ny=0, ny1=-1)
            elif cbar_loc == "top":
                sz = Size.Fraction(self._ncols, Size.AxesY(self.axes_llc))
                v.append(Size.from_any(self._colorbar_pad, sz))
                v.append(Size.from_any(self._colorbar_size, sz))
                # Last vertical slot, spanning all horizontal slots.
                locator = new_locator(nx=0, nx1=-1, ny=-2)
            # Other locations: ``locator`` is still the one of the last
            # image axes from the loop above.
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
            self.cbar_axes[0].set_axes_locator(locator)
            self.cbar_axes[0].set_visible(True)
        elif cbar_mode == "each":
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(True)
        else:
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
                self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
                                               which="active")

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)


# ``AxesGrid`` is an alias: both names refer to the same class object.
AxesGrid = ImageGrid



# Demo for Grid: a 2x2 grid filling subplot position 111.  The block is
# disabled (``if 0:``); the original ``__main__`` guard is kept as a comment.
#if __name__ == "__main__":
if 0:
    F = plt.figure(1, (7, 6))
    F.clf()

    F.subplots_adjust(left=0.15, right=0.9)

    grid = Grid(F, 111, # similar to subplot(111)
                nrows_ncols = (2, 2),
                direction="row",
                axes_pad = 0.05,
                add_all=True,
                label_mode = "1",
                )



# Demo for ImageGrid: three 2x2 grids side by side (subplot positions 131,
# 132 and 133) without colorbar, with a single colorbar and with one
# colorbar per axes.  The block is disabled (``if 0:``); the original
# ``__main__`` guard is kept as a comment.
#if __name__ == "__main__":
if 0:
    from axes_divider import get_demo_image
    F = plt.figure(1, (9, 3.5))
    F.clf()

    F.subplots_adjust(left=0.05, right=0.98)

    # First grid: no colorbar, axes shared per column/row only.
    grid = ImageGrid(F, 131, # similar to subplot(131)
                    nrows_ncols = (2, 2),
                    direction="row",
                    axes_pad = 0.05,
                    add_all=True,
                    label_mode = "1",
                    )

    Z, extent = get_demo_image()
    # Interactive mode is switched off while the images are added.
    plt.ioff()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")

    # This only affects axes in first column and second row as share_all = False.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])
    plt.ion()


    # Second grid: all axes shared, one colorbar for the whole grid.
    grid = ImageGrid(F, 132, # similar to subplot(132)
                    nrows_ncols = (2, 2),
                    direction="row",
                    axes_pad = 0.0,
                    add_all=True,
                    share_all=True,
                    label_mode = "1",
                    cbar_mode="single",
                    )

    Z, extent = get_demo_image()
    plt.ioff()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")
    # The colorbar is drawn for the last image only, in the first cbar axes.
    plt.colorbar(im, cax = grid.cbar_axes[0])
    plt.setp(grid.cbar_axes[0].get_yticklabels(), visible=False)

    # This affects all axes as share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])

    plt.ion()



    # Third grid: all axes shared, one horizontal colorbar on top of each axes.
    grid = ImageGrid(F, 133, # similar to subplot(133)
                    nrows_ncols = (2, 2),
                    direction="row",
                    axes_pad = 0.1,
                    add_all=True,
                    label_mode = "1",
                    share_all = True,
                    cbar_location="top",
                    cbar_mode="each",
                    cbar_size="7%",
                    cbar_pad="2%",
                    )
    plt.ioff()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")
        plt.colorbar(im, cax = grid.cbar_axes[i],
                     orientation="horizontal")
        grid.cbar_axes[i].xaxis.set_ticks_position("top")
        plt.setp(grid.cbar_axes[i].get_xticklabels(), visible=False)

    # This affects all axes as share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])

    plt.ion()
    plt.draw()