Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
chaco / colormapped_scatterplot.py
Size: Mime:
""" Defines the ColormappedScatterPlot and ColormappedScatterPlotView classes.
"""

from __future__ import with_statement

import six
import six.moves as sm
# Major library imports
from numpy import argsort, array, concatenate, nonzero, invert, take, \
                  isnan, transpose, newaxis, zeros, ndarray

# Enthought library imports
from kiva.constants import STROKE
from traits.api import Dict, Enum, Float, Instance, on_trait_change
from traitsui.api import Item, RangeEditor

# Local, relative imports
from .array_data_source import ArrayDataSource
from .base import left_shift, right_shift
from .abstract_colormap import AbstractColormap
from .scatterplot import ScatterPlot, ScatterPlotView


class ColormappedScatterPlotView(ScatterPlotView):
    """ Traits UI View for customizing a color-mapped scatter plot.

    Starts from the ScatterPlotView layout and appends an editor for the
    plot's ``fill_alpha`` trait, constrained by a RangeEditor to [0.0, 1.0].
    """

    def __init__(self):
        super(ColormappedScatterPlotView, self).__init__()
        # The new control goes into the first sub-group of the view's
        # top-level content group.
        first_group = self.content.content[0]
        alpha_item = Item("fill_alpha", label="Fill alpha",
                          editor=RangeEditor(low=0.0, high=1.0))
        first_group.content.append(alpha_item)


class ColormappedScatterPlot(ScatterPlot):
    """
    A scatter plot that allows each point to take on a different color,
    corresponding to a color map.

    If the **color_data** or **color_mapper** attributes are None, then it
    behaves like a normal ScatterPlot.
    """

    #: Source for color data.
    color_data = Instance(ArrayDataSource)

    #: Mapping for colors.
    color_mapper = Instance(AbstractColormap)

    #: The alpha value to apply to the result of the color-mapping process.
    #: (This makes it easier to create color maps without having to worry
    #: about alpha.)
    fill_alpha = Float(1.0)

    #: Determines what drawing approach to use:
    #:
    #: banded:
    #:     Draw the points color-band by color-band, thus reducing the number of
    #:     set_stroke_color() calls. Disadvantage is that some colors will
    #:     appear more prominently than others if there are a lot of
    #:     overlapping points.
    #: bruteforce:
    #:     Set the stroke color before drawing each marker.  Slower, but doesn't
    #:     produce the banding effect that puts some colors on top of others;
    #:     useful if there is a lot of overlap of the data.
    #: auto:
    #:     Determines which render method to use based on the number of points
    #:
    #: TODO: Based on preliminary results, "banded" isn't significantly
    #: more expensive than "bruteforce" for small datasets (<1000),
    #: so perhaps banded should be removed.
    render_method = Enum("auto", "banded", "bruteforce")

    # A dict mapping color-map indices to arrays of row indices into the
    # points last passed to _compute_bands().
    # This is used for the "banded" render method.
    # This mapping is only valid if **_cache_valid** is True.
    _index_bands = Dict()

    #: Traits UI View for customizing the plot. Overrides the ScatterPlot value.
    traits_view = ColormappedScatterPlotView()

    #------------------------------------------------------------------------
    # BaseXYPlot interface
    #------------------------------------------------------------------------

    def map_screen(self, data_array):
        """
        Maps an array of data points into screen space, and returns them as
        an array.

        The *data_array* parameter must be an Nx2 (index, value) or Nx3
        (index, value, color_value) array. The returned array is an Nx2
        array of (x, y) tuples. A color column, if present, is dropped
        before delegating to the base-class mapping.
        """
        if len(data_array) > 0 and data_array.shape[1] == 3:
            data_array = data_array[:, :2]
        return super(ColormappedScatterPlot, self).map_screen(data_array)

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        """ Draws the 'plot' layer.

        Overrides BaseXYPlot, which isn't really fully generic (it assumes that
        the output of map_screen() is sufficient to render the data).

        When the gathered points carry no color column, the base-class drawing
        path is used, with the alpha of the plot color scaled by
        **fill_alpha**. Otherwise the points are mapped to screen space,
        re-joined with their color values, and handed to _render().
        """
        self._gather_points()
        if len(self._cached_data_pts) == 0:
            return

        if self._cached_data_pts.shape[1] == 2:
            # Take fill_alpha into account even when rendering with only two
            # values. Only the alpha channel is scaled: scaling the RGB
            # channels as well would darken the color rather than make it
            # translucent.
            old_color = self.color
            rgba = tuple(self.color_)
            if len(rgba) > 3:
                alpha = rgba[3] * self.fill_alpha
            else:
                alpha = self.fill_alpha
            self.color = tuple(rgba[:3]) + (alpha,)
            try:
                super(ColormappedScatterPlot, self)._draw_component(
                    gc, view_bounds, mode)
            finally:
                # Always restore the caller's color, even if drawing fails.
                self.color = old_color
        else:
            colors = self._cached_data_pts[:, 2]
            screen_pts = self.map_screen(self._cached_data_pts)
            pts = concatenate((screen_pts, colors[:, newaxis]), axis=1)
            self._render(gc, pts)

    def _gather_points(self):
        """
        Collects the data points that are within the plot bounds and caches
        them.

        Fills self._cached_data_pts with an Nx2 (index, value) array, or an
        Nx3 (index, value, color) array when **color_data** is set. The array
        holds only the points that pass the data masks, the NaN checks, and
        the index/value range masks. The boolean selection mask is stored in
        self._cached_point_mask. Does nothing if the cache is already valid.
        """
        if self._cache_valid:
            return

        if not self.index or not self.value:
            self._cached_data_pts = []
            self._cache_valid = True
            return

        index, index_mask = self.index.get_data_mask()
        value, value_mask = self.value.get_data_mask()

        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
            self._cached_data_pts = []
            self._cache_valid = True
            return

        index_range_mask = self.index_mapper.range.mask_data(index)
        value_range_mask = self.value_mapper.range.mask_data(value)
        # Drop points whose index or value is NaN. The NaN test must be
        # applied to the data itself; applied to the boolean masks it is
        # always True and filters nothing.
        nan_mask = invert(isnan(index)) & invert(isnan(value))
        point_mask = index_mask & value_mask & nan_mask & \
                     index_range_mask & value_range_mask

        if self.color_data is not None:
            if self.color_data.is_masked():
                color_data, color_mask = self.color_data.get_data_mask()
                point_mask = point_mask & color_mask
            else:
                color_data = self.color_data.get_data()

            # Points without a valid (non-NaN) color value are not drawn.
            color_nan_mask = invert(isnan(color_data))

            point_mask = point_mask & color_nan_mask
            points = transpose(array((index, value, color_data)))
        else:
            points = transpose(array((index, value)))

        self._cached_data_pts = points[point_mask]
        self._cached_point_mask = point_mask

        self._cache_valid = True

    def _render(self, gc, points):
        """ Actually draws the plot.

        Overrides the ScatterPlot implementation. *points* is expected to be
        an Nx3 array of (screen x, screen y, color value) rows when color
        mapping is active. Falls back to the base class when there is no
        color mapper or color data. Uses brute-force drawing if the GC cannot
        batch-draw markers.
        """
        # If we don't have a color data set, then use the base class to render
        if (self.color_mapper is None) or (self.color_data is None):
            return super(ColormappedScatterPlot, self)._render(gc, points)

        # If the GC doesn't have draw_*_at_points, then use bruteforce
        batch_capable = (hasattr(gc, 'draw_marker_at_points') or
                         hasattr(gc, 'draw_path_at_points'))

        if self.render_method == 'auto':
            method = self._calc_render_method(len(points))
        else:
            method = self.render_method

        with gc:
            if method == 'bruteforce' or (not batch_capable):
                self._render_bruteforce(gc, points)
            elif method == 'banded':
                self._render_banded(gc, points)

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _compute_bands(self, points, smartmode=False):
        """
        Groups *points* into color bands, filling in self._index_bands.

        *points* is an Nx3 array whose third column holds the color values.
        Each entry of self._index_bands maps a color-map index (as returned
        by ``color_mapper.map_index``) to the array of row indices of
        *points* that map to that color.

        If *smartmode* is True and **render_method** is 'bruteforce', the
        band computation is skipped. Unless *points* is empty or there is no
        color mapper, the mapped color indices are stored in
        self._color_indices and the cache is marked valid.
        """
        if len(points) == 0:
            return
        if self.color_mapper is None:
            return

        # Map the V values of the (x, y, v) points to color-map indices.
        color_data = points[:, 2]
        color_indices = self.color_mapper.map_index(color_data)

        if not (smartmode and self.render_method == 'bruteforce'):
            # shuffle_indices indicates how to sort the points so that their
            # color_indices are in order.  We don't really care about the
            # sorting so much as the fact that once they are sorted, points
            # of the same color are grouped together into "bands".
            shuffle_indices = argsort(color_indices)

            # This pulls values from the color_indices array into
            # sorted_color_indices, using the results of the sort we just did.
            sorted_color_indices = take(color_indices, shuffle_indices)

            # Now we want to determine where the continuous bands are.  We do
            # this by right-shifting the sorted_color_indices array,
            # subtracting it from the original, and looking for all the
            # nonzero points.
            shifted = right_shift(sorted_color_indices, sorted_color_indices[0])
            start_indices = concatenate(
                [[0], nonzero(sorted_color_indices - shifted)[0]])
            end_indices = left_shift(start_indices, len(sorted_color_indices))

            # Store the shuffled indices in self._index_bands.  We don't store
            # the actual data points because we need to allow the renderer to
            # index into the mapped XY screen positions.
            self._index_bands = {}
            for start, end in zip(start_indices, end_indices):
                color_index = sorted_color_indices[start]
                self._index_bands[color_index] = shuffle_indices[start:end]

        self._color_indices = color_indices
        self._cache_valid = True

    def _calc_render_method(self, numpoints):
        """ Returns a string indicating the render method.

        'banded' is chosen for more than 1000 points with a scalar (float)
        marker size, since banded rendering requires a single marker size.
        Otherwise the result is 'bruteforce'.
        """
        if numpoints > 1000 and isinstance(self.marker_size, float):
            return 'banded'
        else:
            return "bruteforce"

    def _set_draw_info(self, gc, mode, color, outline_color=None,
                       outline_weight=None):
        """ Sets the stroke color, fill color, and line width on the graphics
        context.

        The alpha of *color* is replaced by **fill_alpha**. In STROKE mode the
        marker is drawn with the stroke only, so *color* becomes the stroke
        color. Otherwise *color* becomes the fill color, and *outline_color*
        (when given) becomes the stroke color. *outline_weight*, when given,
        sets the line width.
        """
        color = tuple(color[:3]) + (self.fill_alpha,)
        if mode == STROKE:
            # Stroke-only markers have no fill, so the mapped color must go on
            # the stroke on every call. Otherwise each band after the first
            # would reuse the previous band's stroke color.
            gc.set_stroke_color(color)
        else:
            if outline_color is not None:
                gc.set_stroke_color(outline_color)
            gc.set_fill_color(color)
        if outline_weight is not None:
            gc.set_line_width(outline_weight)

    def _render_banded(self, gc, points):
        """ Draws the points color-band by color-band.

        *points* is an Nx3 array of (screen x, screen y, color value) rows.
        Requires a scalar (float) **marker_size** and a GC that provides
        draw_marker_at_points() or draw_path_at_points().

        Raises RuntimeError if the GC provides neither batch method.
        """
        if len(points) == 0:
            # _compute_bands() leaves _index_bands untouched for empty input,
            # so drawing now could index into stale bands.
            return
        self._compute_bands(points)

        # Grab the XY values; the index bands select rows from these.
        xy_points = points[:, 0:2]

        marker = self.marker_
        size = self.marker_size
        assert isinstance(size, float), "Variable size markers not implemented for banded rendering"

        # Set up the GC for drawing
        gc.set_line_dash(None)
        if marker.draw_mode == STROKE:
            gc.set_line_width(self.line_width)

        gc.begin_path()

        cmap = self.color_mapper

        if (hasattr(gc, 'draw_marker_at_points') and
                self.marker not in ('custom', 'circle', 'diamond')):
            # This is the fastest method: we use one of the built-in markers.
            color_bands = cmap.color_bands
            mode = marker.draw_mode
            # Initial setup of drawing parameters (outline color and width)
            self._set_draw_info(gc, mode, color_bands[0],
                                self.outline_color_, self.line_width)
            for color_index, indices in self._index_bands.items():
                self._set_draw_info(gc, mode, color_bands[color_index])
                gc.draw_marker_at_points(xy_points[indices], size,
                                         marker.kiva_marker)

        elif hasattr(gc, 'draw_path_at_points'):
            point_bands = {
                color_index: xy_points[indices]
                for color_index, indices in self._index_bands.items()
            }
            # We have to construct the path for the marker.
            if self.marker != 'custom':
                path = gc.get_empty_path()
                # turn the class into an instance... we should make
                # add_to_path a class method at some point.
                marker().add_to_path(path, size)
                mode = marker.draw_mode
            else:
                path = self.custom_symbol
                mode = STROKE

            color_bands = cmap.color_bands
            for color_index, xy in point_bands.items():
                self._set_draw_info(gc, mode, color_bands[color_index],
                                    self.outline_color_, self.line_width)
                gc.draw_path_at_points(xy, path, mode)
        else:
            raise RuntimeError("Batch drawing requested on non-batch-capable GC.")

    def _render_bruteforce(self, gc, points):
        """ Draws the points one at a time, setting the color for each one.

        *points* is an Nx3 array of (screen x, screen y, color value) rows.
        Each color value is mapped through **color_mapper**, and its alpha is
        replaced with **fill_alpha**. Filled markers use the mapped color as
        the fill color and **outline_color** as the stroke color. Stroke-only
        markers and custom symbols are stroked with the mapped color.
        """
        x, y, colors = transpose(points)

        # Map the colors and replace their alpha with fill_alpha.
        colors = self.color_mapper.map_screen(colors)
        alphas = (zeros(len(colors)) + self.fill_alpha)[:, newaxis]
        colors = concatenate((colors[:, :3], alphas), axis=1)

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            gc.set_stroke_color(self.outline_color_)
            gc.set_line_width(self.line_width)

            marker_cls = self.marker_
            marker_size = self.marker_size
            if isinstance(marker_size, ndarray) and \
                    self._cached_point_mask is not None:
                # Per-point sizes must be filtered the same way as the points.
                marker_size = marker_size[self._cached_point_mask]
            mode = marker_cls.draw_mode

            # Compare the marker *name*. self.marker_ is a marker class and
            # never equals the string "custom", so testing it would make the
            # custom-symbol branch unreachable.
            if self.marker != "custom":
                if (hasattr(gc, "draw_marker_at_points") and
                        self.marker not in ('custom', 'circle', 'diamond')):
                    def draw_func(px, py, size):
                        gc.draw_marker_at_points([[px, py]], size,
                                                 marker_cls.kiva_marker)

                elif hasattr(gc, "draw_path_at_points"):
                    # turn the class into an instance... we should make
                    # add_to_path a class method at some point.
                    m = marker_cls()

                    def draw_func(px, py, size):
                        path = gc.get_empty_path()
                        m.add_to_path(path, size)
                        gc.draw_path_at_points([[px, py]], path, mode)
                else:
                    m = marker_cls()

                    def draw_func(px, py, size):
                        gc.translate_ctm(px, py)
                        gc.begin_path()
                        m.add_to_path(gc, size)
                        gc.draw_path(mode)
                        gc.translate_ctm(-px, -py)

                stroke_only = (mode == STROKE)
                for i in range(len(x)):
                    if isinstance(marker_size, float):
                        size = marker_size
                    else:
                        size = marker_size[i]
                    gc.set_fill_color(colors[i])
                    if stroke_only:
                        # Stroke-only markers have no fill, so the mapped
                        # color must be applied to the stroke to be visible.
                        gc.set_stroke_color(colors[i])
                    draw_func(x[i], y[i], size)

            else:
                path = self.custom_symbol
                for i in range(len(x)):
                    # Custom symbols are drawn in STROKE mode, so the mapped
                    # color goes on the stroke (as in _render_banded).
                    gc.set_fill_color(colors[i])
                    gc.set_stroke_color(colors[i])
                    gc.draw_path_at_points([[x[i], y[i]]], path, STROKE)

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _color_data_changed(self, old, new):
        """ Moves the data_changed listener from the old color data source to
        the new one, then treats the change like any other data change.
        """
        if old is not None:
            old.on_trait_change(self._either_data_changed, "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _color_mapper_changed(self, old, new):
        """ Invalidates the cache and redraws when the color mapper is
        replaced.
        """
        self._cache_valid = False

        if hasattr(new, 'range') and new.range is None and old is not None:
            # Someone passed in a ColorMapper that has no range associated with
            # it. Use the range on the old ColorMapper.
            new.range = old.range

        self.invalidate_draw()
        self.request_redraw()

    @on_trait_change('color_mapper:updated')
    def _color_mapper_updated(self):
        """ Redraws when the current color mapper fires its 'updated' event.
        """
        self.invalidate_draw()
        self.request_redraw()

    def _fill_alpha_changed(self):
        """ Redraws when fill_alpha changes. """
        self.invalidate_draw()
        self.request_redraw()

# EOF