Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
chaco / barplot.py
Size: Mime:
""" Defines the BarPlot class.
"""

from __future__ import with_statement

import logging

from numpy import array, compress, column_stack, invert, isnan, transpose, zeros
from traits.api import Any, Bool, Enum, Float, Instance, Property, \
        Range, Tuple, cached_property, on_trait_change
from enable.api import black_color_trait
from kiva.constants import FILL_STROKE

# Local relative imports
from .abstract_plot_renderer import AbstractPlotRenderer
from .abstract_mapper import AbstractMapper
from .array_data_source import ArrayDataSource
from .base import reverse_map_1d


#: Module-level logger; BarPlot uses it to warn about unusable datasets.
logger = logging.getLogger(name=__name__)


# TODO: make child of BaseXYPlot

class BarPlot(AbstractPlotRenderer):
    """
    A renderer for bar charts.

    Along the index axis, each bar is **bar_width** wide, measured either in
    data units or in pixels (see **bar_width_type**). Along the value axis,
    each bar runs from its starting value (0 unless **starting_value** is
    given) to its value. With ``orientation == "h"`` the index axis is the
    screen x axis; with ``orientation == "v"`` it is the screen y axis.
    """
    #: The data source to use for the index coordinate.
    index = Instance(ArrayDataSource)

    #: The data source to use as value points.
    value = Instance(ArrayDataSource)

    #: The data source to use as "starting" values for bars (along value axis).
    #: For instance, if the values are [10, 20] and starting_value
    #: is [3, 7], BarPlot will plot two bars, one between 3 and 10, and
    #: one between 7 and 20.
    starting_value = Instance(ArrayDataSource)

    #: Mapper between index data values and screen coordinates.
    index_mapper = Instance(AbstractMapper)
    #: Mapper between value data values and screen coordinates.
    value_mapper = Instance(AbstractMapper)

    #: The orientation of the index axis ("h": index along screen x,
    #: "v": index along screen y).
    orientation = Enum("h", "v")

    #: The direction of the index axis with respect to the graphics context's
    #: direction.
    index_direction = Enum("normal", "flipped")

    #: The direction of the value axis with respect to the graphics context's
    #: direction.
    value_direction = Enum("normal", "flipped")

    #: Type of width used for bars:
    #:
    #: 'data'
    #:     The width is in data units along the index dimension.
    #: 'screen'
    #:     The width is a fixed number of pixels.
    bar_width_type = Enum("data", "screen")

    #: Width of the bars, in data or screen space (determined by
    #: **bar_width_type**).
    bar_width = Float(10)

    #: Round on rectangle dimensions? This is not strictly an "antialias", but
    #: it has the same effect through exact pixel drawing.
    antialias = Bool(True)

    #: Width of the border of the bars.
    line_width = Float(1.0)
    #: Color of the border of the bars.
    line_color = black_color_trait
    #: Color to fill the bars.
    fill_color = black_color_trait

    #: The RGBA tuple for rendering lines.  It is always a tuple of length 4.
    #: It has the same RGB values as line_color_, and its alpha value is the
    #: alpha value of self.line_color multiplied by self.alpha.
    effective_line_color = Property(Tuple, depends_on=['line_color', 'alpha'])

    #: The RGBA tuple for rendering the fill.  It is always a tuple of length 4.
    #: It has the same RGB values as fill_color_, and its alpha value is the
    #: alpha value of self.fill_color multiplied by self.alpha.
    effective_fill_color = Property(Tuple, depends_on=['fill_color', 'alpha'])

    #: Overall alpha multiplier applied to both the line and fill colors.
    #: Ranges from 0.0 (transparent) to 1.0 (opaque).
    alpha = Range(0.0, 1.0, 1.0)

    # Convenience properties that correspond to either index_mapper or
    # value_mapper, depending on the orientation of the plot.

    #: Corresponds to either **index_mapper** or **value_mapper**, depending on
    #: the orientation of the plot.
    x_mapper = Property
    #: Corresponds to either **value_mapper** or **index_mapper**, depending on
    #: the orientation of the plot.
    y_mapper = Property

    #: Corresponds to either **index_direction** or **value_direction**,
    #: depending on the orientation of the plot.
    x_direction = Property
    #: Corresponds to either **value_direction** or **index_direction**,
    #: depending on the orientation of the plot.
    y_direction = Property

    #: Convenience property for accessing the index data range.
    index_range = Property
    #: Convenience property for accessing the value data range.
    value_range = Property

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Indicates whether or not the data cache is valid
    _cache_valid = Bool(False)

    # Cached data values from the datasources.  If **bar_width_type** is "data",
    # then this is an Nx4 array of (index_low, index_high, start, end).  If
    # **bar_width_type** is "screen", then this is an Nx3 array of
    # (index_center, start, end).  It is None while no points have been
    # gathered (e.g. when a data source is missing).
    _cached_data_pts = Any

    #------------------------------------------------------------------------
    # AbstractPlotRenderer interface
    #------------------------------------------------------------------------

    def __init__(self, *args, **kw):
        # These Traits depend on others, so we'll defer setting them until
        # after the HasTraits initialization has been completed.
        later_list = ['index_direction', 'value_direction']
        postponed = {}
        for name in later_list:
            if name in kw:
                postponed[name] = kw.pop(name)

        super(BarPlot, self).__init__(*args, **kw)

        # Set any keyword Traits that were postponed.
        self.trait_set(**postponed)

    def map_screen(self, data_array):
        """ Maps an Nx2 array of (index, value) data points into screen space
        and returns it as an Nx2 array of (x, y) screen points.

        Implements the AbstractPlotRenderer interface.  Returns an empty list
        for empty input.
        """
        if len(data_array) == 0:
            return []
        x_ary, y_ary = transpose(data_array)
        sx = self.index_mapper.map_screen(x_ary)
        sy = self.value_mapper.map_screen(y_ary)

        if self.orientation == "h":
            return transpose(array((sx, sy)))
        else:
            # Index runs along screen y, so it becomes the second column.
            return transpose(array((sy, sx)))

    def map_data(self, screen_pt):
        """ Maps a screen space point into the "index" space of the plot.

        Implements the AbstractPlotRenderer interface.
        """
        if self.orientation == "h":
            screen_coord = screen_pt[0]
        else:
            screen_coord = screen_pt[1]
        return self.index_mapper.map_data(screen_coord)

    def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True,
                  index_only=False):
        """ Maps a screen space point to an index into the plot's index array(s).

        Implements the AbstractPlotRenderer interface.

        Parameters
        ----------
        screen_pt : (x, y)
            The screen point to look up.
        threshold : float
            Maximum screen-space distance between *screen_pt* and the
            (index, value) point of the candidate bar.
        outside_returns_none : bool
            If True, return None when the point maps outside the index range.
        index_only : bool
            If True, accept a match when only the distance along the index
            axis is within *threshold*.

        Returns
        -------
        The matching array index, or None.
        """
        data_pt = self.map_data(screen_pt)
        if ((data_pt < self.index_mapper.range.low) or
                (data_pt > self.index_mapper.range.high)) and \
                outside_returns_none:
            return None
        index_data = self.index.get_data()
        value_data = self.value.get_data()

        if len(value_data) == 0 or len(index_data) == 0:
            return None

        try:
            ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
        except IndexError:
            return None

        x = index_data[ndx]
        y = value_data[ndx]

        result = self.map_screen(array([[x, y]]))
        if len(result) == 0:
            return None

        sx, sy = result[0]
        # The index axis is screen x for "h" and screen y for "v"; the
        # offset must be measured in absolute value on that axis.
        if self.orientation == "h":
            index_offset = screen_pt[0] - sx
        else:
            index_offset = screen_pt[1] - sy

        if index_only and abs(index_offset) < threshold:
            return ndx
        elif ((screen_pt[0] - sx) ** 2 + (screen_pt[1] - sy) ** 2
              < threshold * threshold):
            return ndx
        else:
            return None

    #------------------------------------------------------------------------
    # PlotComponent interface
    #------------------------------------------------------------------------

    def _gather_points(self):
        """ Collects data points that are within the range of the plot, and
        caches them in **_cached_data_pts**.

        A point is kept when its index is unmasked, is not NaN, and lies
        inside the index mapper's range, and its value is not NaN.  If either
        data source is missing, the cache is cleared (set to None) and left
        invalid.
        """
        # Check for missing data sources before dereferencing them.
        if self.index is None or self.value is None:
            self._cached_data_pts = None
            return

        index, index_mask = self.index.get_data_mask()
        value, _ = self.value.get_data_mask()

        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
            logger.warning(
                "Chaco: using empty dataset; index_len=%d, value_len=%d.",
                len(index), len(value))
            self._cached_data_pts = array([])
            self._cache_valid = True
            return

        # TODO: Until we code up a better handling of value-based culling that
        # takes into account starting_value and dataspace bar widths, just use
        # the index culling for now.
        index_range_mask = self.index_mapper.range.mask_data(index)
        # NaN-check the data themselves (a boolean mask can never be NaN), so
        # bars with a NaN position or height are dropped.
        nan_mask = self._not_nan_mask(index) & self._not_nan_mask(value)
        point_mask = index_mask & nan_mask & index_range_mask

        if self.starting_value is None:
            starting_values = zeros(len(index))
        else:
            starting_values = self.starting_value.get_data()

        if self.bar_width_type == "data":
            half_width = self.bar_width / 2.0
            points = column_stack((index - half_width, index + half_width,
                                   starting_values, value))
        else:
            points = column_stack((index, starting_values, value))
        self._cached_data_pts = compress(point_mask, points, axis=0)

        self._cache_valid = True

    @staticmethod
    def _not_nan_mask(data):
        """ Returns a boolean array that is True where *data* is not NaN.

        Arrays whose dtype does not support ``isnan`` (for example object
        arrays) are treated as containing no NaNs.
        """
        try:
            return invert(isnan(data))
        except TypeError:
            return invert(zeros(len(data), dtype=bool))

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        """ Draws the 'plot' layer: one filled and stroked rectangle per bar.
        """
        if not self._cache_valid:
            self._gather_points()

        data = self._cached_data_pts
        # None means no data source is attached; an empty array means there
        # is nothing to draw.
        if data is None or data.size == 0:
            return

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            gc.set_antialias(self.antialias)
            gc.set_stroke_color(self.effective_line_color)
            gc.set_fill_color(self.effective_fill_color)
            gc.set_line_width(self.line_width)

            if self.bar_width_type == "data":
                # map the bar start and stop locations into screen space
                lower_left_pts = self.map_screen(data[:, (0, 2)])
                upper_right_pts = self.map_screen(data[:, (1, 3)])
            else:
                half_width = self.bar_width / 2.0
                # Screen column that runs along the index axis: x for "h",
                # y for "v" (map_screen swaps columns for "v").
                index_col = 0 if self.orientation == "h" else 1
                # map the bar centers into screen space and then compute the
                # bar start and end positions along the index axis
                lower_left_pts = self.map_screen(data[:, (0, 1)])
                upper_right_pts = self.map_screen(data[:, (0, 2)])
                lower_left_pts[:, index_col] -= half_width
                upper_right_pts[:, index_col] += half_width

            bounds = upper_right_pts - lower_left_pts
            gc.rects(column_stack((lower_left_pts, bounds)))
            gc.draw_path()

    def _draw_default_axes(self, gc):
        """ Draws origin lines (index == 0 and/or value == 0) when zero lies
        strictly inside the corresponding data range.
        """
        if not self.origin_axis_visible:
            return

        idx_rng = self.index_mapper.range
        val_rng = self.value_mapper.range

        # Each segment is a pair of (index, value) data points.
        segments = []
        if idx_rng.low < 0 < idx_rng.high:
            # Line at index == 0 spanning the whole value range.
            segments.append([[0.0, val_rng.low], [0.0, val_rng.high]])
        if val_rng.low < 0 < val_rng.high:
            # Line at value == 0 spanning the whole index range.
            segments.append([[idx_rng.low, 0.0], [idx_rng.high, 0.0]])

        with gc:
            gc.set_stroke_color(self.origin_axis_color_)
            gc.set_line_width(self.origin_axis_width)
            gc.set_line_dash(None)

            for segment in segments:
                start, end = self.map_screen(array(segment))
                # Offset by half a pixel so 1px lines land on pixel centers.
                gc.move_to(int(start[0]) + 0.5, int(start[1]) + 0.5)
                gc.line_to(int(end[0]) + 0.5, int(end[1]) + 0.5)
                gc.stroke_path()

    def _render_icon(self, gc, x, y, width, height):
        """ Draws a legend icon: a bar-colored rectangle filling the central
        half of the given box.
        """
        with gc:
            gc.set_fill_color(self.effective_fill_color)
            gc.set_stroke_color(self.effective_line_color)
            gc.rect(x + width / 4, y + height / 4, width / 2, height / 2)
            gc.draw_path(FILL_STROKE)

    def _post_load(self):
        super(BarPlot, self)._post_load()

    #------------------------------------------------------------------------
    # Properties
    #------------------------------------------------------------------------

    def _get_index_range(self):
        return self.index_mapper.range

    def _set_index_range(self, val):
        self.index_mapper.range = val

    def _get_value_range(self):
        return self.value_mapper.range

    def _set_value_range(self, val):
        self.value_mapper.range = val

    def _get_x_mapper(self):
        if self.orientation == "h":
            return self.index_mapper
        else:
            return self.value_mapper

    def _get_y_mapper(self):
        if self.orientation == "h":
            return self.value_mapper
        else:
            return self.index_mapper

    def _get_x_direction(self):
        if self.orientation == "h":
            return self.index_direction
        else:
            return self.value_direction

    def _get_y_direction(self):
        if self.orientation == "h":
            return self.value_direction
        else:
            return self.index_direction

    #------------------------------------------------------------------------
    # Event handlers - these are mostly copied from BaseXYPlot
    #------------------------------------------------------------------------

    def _update_mappers(self):
        """ Updates the index and value mappers. Called by trait change handlers
        for various traits.

        Assigns each mapper's screen positions from this component's bounds,
        honoring orientation and the per-axis direction, then invalidates the
        draw and the data cache.
        """
        x_mapper = self.index_mapper
        y_mapper = self.value_mapper
        x_dir = self.index_direction
        y_dir = self.value_direction

        if self.orientation == "v":
            x_mapper, y_mapper = y_mapper, x_mapper
            x_dir, y_dir = y_dir, x_dir

        x = self.x
        x2 = self.x2
        y = self.y
        y2 = self.y2

        if x_mapper is not None:
            if x_dir == "normal":
                x_mapper.low_pos = x
                x_mapper.high_pos = x2
            else:
                x_mapper.low_pos = x2
                x_mapper.high_pos = x

        if y_mapper is not None:
            if y_dir == "normal":
                y_mapper.low_pos = y
                y_mapper.high_pos = y2
            else:
                y_mapper.low_pos = y2
                y_mapper.high_pos = y

        self.invalidate_draw()
        self._cache_valid = False

    @on_trait_change('line_color, line_width, fill_color, alpha')
    def _attributes_changed(self):
        self.invalidate_draw()
        self.request_redraw()

    def _bounds_changed(self, old, new):
        super(BarPlot, self)._bounds_changed(old, new)
        self._update_mappers()

    def _bounds_items_changed(self, event):
        super(BarPlot, self)._bounds_items_changed(event)
        self._update_mappers()

    def _orientation_changed(self):
        self._update_mappers()

    def _index_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._either_data_changed, "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _index_direction_changed(self):
        # Flip the index mapper's screen endpoints (if a mapper is attached).
        m = self.index_mapper
        if m is not None:
            m.low_pos, m.high_pos = m.high_pos, m.low_pos
        self.invalidate_draw()

    def _value_direction_changed(self):
        # Flip the value mapper's screen endpoints (if a mapper is attached).
        m = self.value_mapper
        if m is not None:
            m.low_pos, m.high_pos = m.high_pos, m.low_pos
        self.invalidate_draw()

    def _either_data_changed(self):
        self.invalidate_draw()
        self._cache_valid = False
        self.request_redraw()

    def _value_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._either_data_changed, "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _index_mapper_changed(self, old, new):
        return self._either_mapper_changed(old, new)

    def _value_mapper_changed(self, old, new):
        return self._either_mapper_changed(old, new)

    def _either_mapper_changed(self, old, new):
        """ Moves the "updated" listener from the old mapper to the new one.
        """
        if old is not None:
            old.on_trait_change(self._mapper_updated_handler, "updated",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._mapper_updated_handler, "updated")
        # _gather_points culls against index_mapper.range, so cached points
        # are stale once a mapper is swapped.
        self._cache_valid = False
        self.invalidate_draw()

    def _mapper_updated_handler(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _bar_width_changed(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _bar_width_type_changed(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    #------------------------------------------------------------------------
    # Property getters
    #------------------------------------------------------------------------

    @cached_property
    def _get_effective_line_color(self):
        return self._apply_overall_alpha(self.line_color_)

    @cached_property
    def _get_effective_fill_color(self):
        return self._apply_overall_alpha(self.fill_color_)

    def _apply_overall_alpha(self, color):
        """ Returns *color* as an RGBA tuple whose alpha is the color's own
        alpha (1.0 if it has only RGB components) multiplied by **alpha**.
        """
        own_alpha = color[-1] if len(color) == 4 else 1.0
        return tuple(color[:3]) + (own_alpha * self.alpha,)


### EOF ####################################################################