Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aaronreidsmith / matplotlib   python

Repository URL to install this package:

Version: 3.1.2 

/ axisartist / grid_finder.py

import numpy as np

import matplotlib.ticker as mticker
from matplotlib.transforms import Bbox, Transform
from .clip_path import clip_line_to_rect


# extremes finder
class ExtremeFinderSimple(object):
    """
    Estimate the range of transformed coordinates over a rectangle by
    sampling the rectangle on a regular ``nx`` by ``ny`` grid.
    """

    def __init__(self, nx, ny):
        """
        nx, ny : number of sample points along each image axis.
        """
        self.nx = nx
        self.ny = ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Return the padded ``(lon_min, lon_max, lat_min, lat_max)``.

        transform_xy : callable taking flat x and y arrays and returning
            the matching (lon, lat) arrays.
        x1, y1, x2, y2 : rectangle in image coordinates (0-based).
        """
        sample_x = np.linspace(x1, x2, self.nx)
        sample_y = np.linspace(y1, y2, self.ny)
        grid_x, grid_y = np.meshgrid(sample_x, sample_y)

        # Transform every sample point at once, as two flat arrays.
        lon, lat = transform_xy(grid_x.ravel(), grid_y.ravel())

        return self._add_pad(lon.min(), lon.max(), lat.min(), lat.max())

    def _add_pad(self, lon_min, lon_max, lat_min, lat_max):
        """
        Widen each range by one sampling step on both sides.

        The padding is needed because the current clipping algorithm
        seems to fail when a gridline ends exactly at the bbox boundary.
        """
        pad_lon = (lon_max - lon_min) / self.nx
        pad_lat = (lat_max - lat_min) / self.ny

        return (lon_min - pad_lon, lon_max + pad_lon,
                lat_min - pad_lat, lat_max + pad_lat)


class GridFinderBase(object):
    """
    Compute grid lines, tick locations and tick labels of a curvilinear
    grid inside a rectangle given in image coordinates.

    Derived classes must define "transform_xy, inv_transform_xy"
    (they may use update_transform to do so).
    """

    def __init__(self,
                 extreme_finder,
                 grid_locator1,
                 grid_locator2,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        extreme_finder : callable returning the (lon_min, lon_max,
            lat_min, lat_max) range covered by a rectangle.
        grid_locator1, grid_locator2 : grid locator for 1st and 2nd axis.
        tick_formatter1, tick_formatter2 : tick label formatter for 1st
            and 2nd axis.

        Derived must define "transform_xy, inv_transform_xy"
        (may use update_transform)
        """
        super().__init__()

        self.extreme_finder = extreme_finder
        self.grid_locator1, self.grid_locator2 = grid_locator1, grid_locator2
        self.tick_formatter1 = tick_formatter1
        self.tick_formatter2 = tick_formatter2

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Build the grid description for the rectangle (x1, y1)-(x2, y2).

        Returns a dict with the keys "extremes", "lon_lines", "lat_lines",
        "lon" and "lat".  The "lon" and "lat" entries are the dicts
        produced by _clip_grid_lines_and_find_ticks, extended with a
        "tick_labels" dict keyed by "left", "bottom", "right" and "top".
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # min & max range of lat (or lon) over which each grid line is
        # drawn, i.e. the gridline of lon=0 runs from lat_min to lat_max.
        lon_min, lon_max, lat_min, lat_max = extremes

        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)

        # Levels are expressed in units of 1/factor; undo the scaling to
        # obtain the actual coordinate values.
        lon_head = lon_levs[:lon_n]
        lon_values = np.asarray(
            lon_head if lon_factor is None else lon_head/lon_factor)
        lat_head = lat_levs[:lat_n]
        lat_values = np.asarray(
            lat_head if lat_factor is None else lat_head/lat_factor)

        lon_lines, lat_lines = self._get_raw_grid_lines(
            lon_values, lat_values, lon_min, lon_max, lat_min, lat_max)

        # Clip against a box that is a tiny bit larger than the rectangle.
        pad_x = (x2-x1)*1.e-10
        pad_y = (y2-y1)*1.e-10
        bb = Bbox.from_extents(x1-pad_x, y1-pad_y, x2+pad_x, y2+pad_y)

        grid_info = {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
        }
        grid_info["lon"] = self._clip_grid_lines_and_find_ticks(
            lon_lines, lon_values, lon_levs, bb)
        grid_info["lat"] = self._clip_grid_lines_and_find_ticks(
            lat_lines, lat_values, lat_levs, bb)

        # Format the tick levels found on each side of the rectangle,
        # first for the lon grid and then for the lat grid.
        sides = ("left", "bottom", "right", "top")
        for key, factor in (("lon", lon_factor), ("lat", lat_factor)):
            per_side_levels = grid_info[key]["tick_levels"]
            labels = grid_info[key]["tick_labels"] = {}
            for side in sides:
                if key == "lon":
                    labels[side] = self.tick_formatter1(
                        side, factor, per_side_levels[side])
                else:
                    labels[side] = self.tick_formatter2(
                        side, factor, per_side_levels[side])

        return grid_info

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Return (lon_lines, lat_lines): for every grid value, the result
        of transform_xy applied to 100 points sampled along the line of
        constant lon (resp. lat) across the full lat (resp. lon) range.
        """
        # Sample positions along each grid line, for interpolation.
        lon_samples = np.linspace(lon_min, lon_max, 100)
        lat_samples = np.linspace(lat_min, lat_max, 100)

        lon_lines = []
        for lon in lon_values:
            const_lon = np.zeros_like(lat_samples) + lon
            lon_lines.append(self.transform_xy(const_lon, lat_samples))

        lat_lines = []
        for lat in lat_values:
            const_lat = np.zeros_like(lon_samples) + lat
            lat_lines.append(self.transform_xy(lon_samples, const_lat))

        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Clip every grid line to *bb* and collect the ticks that
        clip_line_to_rect reports for each side of the box.

        Lines for which the clipping result is empty are skipped.
        Returns a dict with the keys "values", "levels", "tick_levels",
        "tick_locs" and "lines".
        """
        sides = ["left", "bottom", "right", "top"]
        tick_levels = {side: [] for side in sides}
        tick_locs = {side: [] for side in sides}
        kept_levels = []
        kept_lines = []

        for (lx, ly), value, lev in zip(lines, values, levs):
            xy, ticks_per_side = clip_line_to_rect(lx, ly, bb)
            if xy:
                kept_levels.append(value)
                kept_lines.append(xy)

                for side_ticks, side in zip(ticks_per_side, sides):
                    found = list(side_ticks)
                    # One level entry is recorded per tick location.
                    tick_levels[side].extend([lev] * len(found))
                    tick_locs[side].extend(found)

        return {
            "values": [],
            "levels": kept_levels,
            "tick_levels": tick_levels,
            "tick_locs": tick_locs,
            "lines": kept_lines,
        }

    def update_transform(self, aux_trans):
        """
        Set transform_xy and inv_transform_xy from *aux_trans*.

        aux_trans : either a Transform instance, or a pair
            (transform_xy, inv_transform_xy) of callables.
        """
        if not isinstance(aux_trans, Transform):
            forward, backward = aux_trans
        else:
            def forward(x, y):
                points = np.column_stack([x, y])
                mapped = aux_trans.transform(points)
                return mapped[:, 0], mapped[:, 1]

            def backward(x, y):
                points = np.column_stack([x, y])
                # The inverse is looked up on every call.
                mapped = aux_trans.inverted().transform(points)
                return mapped[:, 0], mapped[:, 1]

        self.transform_xy = forward
        self.inv_transform_xy = backward

    def update(self, **kw):
        """
        Replace the extreme finder, grid locators and/or tick formatters.

        Raises ValueError on the first keyword that is not one of the
        updatable properties; keywords handled before it stay applied.
        """
        allowed = ("extreme_finder",
                   "grid_locator1",
                   "grid_locator2",
                   "tick_formatter1",
                   "tick_formatter2")
        for name, new_value in kw.items():
            if name not in allowed:
                raise ValueError("unknown update property '%s'" % name)
            setattr(self, name, new_value)


class GridFinder(GridFinderBase):
    """
    GridFinderBase with default helpers filled in and the coordinate
    transform installed at construction time.
    """

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        transform : transform from the image coordinate (which will be
        the transData of the axes) to the world coordinate.

        or transform = (transform_xy, inv_transform_xy)

        locator1, locator2 : grid locator for 1st and 2nd axis.

        Any helper left as None is replaced by a default:
        ExtremeFinderSimple(20, 20), MaxNLocator() and
        FormatterPrettyPrint() respectively.
        """
        super().__init__(
            (ExtremeFinderSimple(20, 20)
             if extreme_finder is None else extreme_finder),
            MaxNLocator() if grid_locator1 is None else grid_locator1,
            MaxNLocator() if grid_locator2 is None else grid_locator2,
            (FormatterPrettyPrint()
             if tick_formatter1 is None else tick_formatter1),
            (FormatterPrettyPrint()
             if tick_formatter2 is None else tick_formatter2))
        self.update_transform(transform)


class MaxNLocator(mticker.MaxNLocator):
    """
    mticker.MaxNLocator that is called with an explicit (v1, v2) range
    and returns ``(levels, number_of_levels, factor)``.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        # The trim argument has no effect; it is only kept for API
        # compatibility.
        mticker.MaxNLocator.__init__(
            self, nbins,
            steps=steps, integer=integer, symmetric=symmetric, prune=prune)
        self.create_dummy_axis()
        self._factor = None

    def __call__(self, v1, v2):
        """
        Locate levels between v1 and v2.

        When a factor has been set, both bounds are multiplied by it
        before locating, and the factor is returned as the third item;
        otherwise the third item is None.
        """
        factor = self._factor
        if factor is None:
            self.set_bounds(v1, v2)
        else:
            self.set_bounds(v1*factor, v2*factor)
        locs = mticker.MaxNLocator.__call__(self)
        return np.array(locs), len(locs), factor

    def set_factor(self, f):
        """Set the factor applied to the bounds (None disables it)."""
        self._factor = f


class FixedLocator(object):
    """
    Locator that picks, from a fixed collection of levels, those lying
    inside the requested range.
    """

    def __init__(self, locs):
        """
        locs : iterable of candidate levels.
        """
        self._locs = locs
        self._factor = None

    def __call__(self, v1, v2):
        """
        Return ``(levels, number_of_levels, factor)`` for the range
        spanned by v1 and v2 (given in either order).

        When a factor has been set, both bounds are multiplied by it
        before they are compared against the stored levels.
        """
        bounds = [v1, v2]
        if self._factor is not None:
            bounds = [bound*self._factor for bound in bounds]
        lower, upper = sorted(bounds)

        selected = []
        for candidate in self._locs:
            if lower <= candidate <= upper:
                selected.append(candidate)
        locs = np.array(selected)
        return locs, len(locs), self._factor

    def set_factor(self, f):
        """Set the factor applied to the bounds (None disables it)."""
        self._factor = f


# Tick Formatter

class FormatterPrettyPrint(object):
    """
    Tick formatter that delegates to a mticker.ScalarFormatter created
    with offsets disabled.
    """

    def __init__(self, useMathText=True):
        """
        useMathText : forwarded to mticker.ScalarFormatter.
        """
        scalar_formatter = mticker.ScalarFormatter(
            useMathText=useMathText, useOffset=False)
        scalar_formatter.create_dummy_axis()
        self._fmt = scalar_formatter
        self._ignore_factor = True

    def __call__(self, direction, factor, values):
        """
        Return the formatted labels for *values*.

        *direction* is not used.  *factor* only matters when
        ``_ignore_factor`` is False: the values are then divided by it
        (a factor of None counts as 1) before formatting.
        """
        if self._ignore_factor:
            return self._fmt.format_ticks(values)

        divisor = 1 if factor is None else factor
        return self._fmt.format_ticks([v / divisor for v in values])


class DictFormatter(object):
    """
    Tick formatter that looks labels up in a dictionary and falls back
    to another formatter (or to empty strings) for missing values.
    """

    def __init__(self, format_dict, formatter=None):
        """
        format_dict : dictionary for format strings to be used.
        formatter : fall-back formatter
        """
        super().__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per value.

        factor is ignored if value is found in the dictionary; it is
        only handed on to the fall-back formatter.
        """
        fallback = self._fallback_formatter
        if fallback:
            defaults = fallback(direction, factor, values)
        else:
            defaults = [""]*len(values)

        labels = []
        for value, default in zip(values, defaults):
            labels.append(self._format_dict.get(value, default))
        return labels