Repository URL to install this package:
Version:
5.0.0 ▾
|
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
""" Defines the BarPlot class.
"""
import logging
from numpy import (
array,
compress,
column_stack,
invert,
isnan,
transpose,
zeros,
)
from traits.api import (
Any,
Bool,
Enum,
Float,
Instance,
Property,
Range,
Tuple,
cached_property,
)
from enable.api import black_color_trait
from kiva.constants import FILL_STROKE
# Local relative imports
from chaco.abstract_plot_renderer import AbstractPlotRenderer
from chaco.abstract_mapper import AbstractMapper
from chaco.array_data_source import ArrayDataSource
from chaco.base import reverse_map_1d
logger = logging.getLogger(__name__)
# TODO: make child of BaseXYPlot
class BarPlot(AbstractPlotRenderer):
"""
A renderer for bar charts.
"""
#: The data source to use for the index coordinate.
index = Instance(ArrayDataSource)
#: The data source to use as value points.
value = Instance(ArrayDataSource)
#: The data source to use as "starting" values for bars (along value axis).
#: For instance, if the values are [10, 20] and starting_value
#: is [3, 7], BarPlot will plot two bars, one between 3 and 10, and
#: one between 7 and 20
starting_value = Instance(ArrayDataSource)
#: Labels for the indices.
index_mapper = Instance(AbstractMapper)
#: Labels for the values.
value_mapper = Instance(AbstractMapper)
#: The orientation of the index axis.
orientation = Enum("h", "v")
#: The direction of the index axis with respect to the graphics context's
#: direction.
index_direction = Enum("normal", "flipped")
#: The direction of the value axis with respect to the graphics context's
#: direction.
value_direction = Enum("normal", "flipped")
#: Type of width used for bars:
#:
#: 'data'
#: The width is in the units along the x-dimension of the data space.
#: 'screen'
#: The width uses a fixed width of pixels.
bar_width_type = Enum("data", "screen")
#: Width of the bars, in data or screen space (determined by
#: **bar_width_type**).
bar_width = Float(10)
#: Round on rectangle dimensions? This is not strictly an "antialias", but
#: it has the same effect through exact pixel drawing.
antialias = Bool(True)
#: Width of the border of the bars.
line_width = Float(1.0, requires_redraw=True)
#: Color of the border of the bars.
line_color = black_color_trait(requires_redraw=True)
#: Color to fill the bars.
fill_color = black_color_trait(requires_redraw=True)
#: The RGBA tuple for rendering lines. It is always a tuple of length 4.
#: It has the same RGB values as :attr:`line_color`, and its alpha value
#: is the alpha value of self.line_color multiplied by self.alpha.
effective_line_color = Property(Tuple, observe=["line_color", "alpha"])
#: The RGBA tuple for rendering the fill. It is always a tuple of length
#: 4. It has the same RGB values as :attr:`fill_color`, and its alpha
#: value is the alpha value of self.fill_color multiplied by self.alpha.
effective_fill_color = Property(Tuple, observe=["fill_color", "alpha"])
#: Overall alpha value of the image. Ranges from 0.0 for transparent to 1.0
alpha = Range(0.0, 1.0, 1.0, requires_redraw=True)
# Convenience properties that correspond to either index_mapper or
# value_mapper, depending on the orientation of the plot.
#: Corresponds to either **index_mapper** or **value_mapper**, depending on
#: the orientation of the plot.
x_mapper = Property
#: Corresponds to either **value_mapper** or **index_mapper**, depending on
#: the orientation of the plot.
y_mapper = Property
#: Corresponds to either **index_direction** or **value_direction**,
#: depending on the orientation of the plot.
x_direction = Property
#: Corresponds to either **value_direction** or **index_direction**,
#: depending on the orientation of the plot
y_direction = Property
#: Convenience property for accessing the index data range.
index_range = Property
#: Convenience property for accessing the value data range.
value_range = Property
# ------------------------------------------------------------------------
# Private traits
# ------------------------------------------------------------------------
# Indicates whether or not the data cache is valid
_cache_valid = Bool(False)
# Cached data values from the datasources. If **bar_width_type** is "data",
# then this is an Nx4 array of (bar_left, bar_right, start, end) for a
# bar plot in normal orientation. If **bar_width_type** is "screen", then
# this is an Nx3 array of (bar_center, start, end).
_cached_data_pts = Any
# ------------------------------------------------------------------------
# AbstractPlotRenderer interface
# ------------------------------------------------------------------------
def __init__(self, *args, **kw):
# These Traits depend on others, so we'll defer setting them until
# after the HasTraits initialization has been completed.
later_list = ["index_direction", "value_direction"]
postponed = {}
for name in later_list:
if name in kw:
postponed[name] = kw.pop(name)
super().__init__(*args, **kw)
# Set any keyword Traits that were postponed.
self.trait_set(**postponed)
def map_screen(self, data_array):
"""Maps an array of data points into screen space and returns it as
an array.
Implements the AbstractPlotRenderer interface.
"""
# data_array is Nx2 array
if len(data_array) == 0:
return []
x_ary, y_ary = transpose(data_array)
sx = self.index_mapper.map_screen(x_ary)
sy = self.value_mapper.map_screen(y_ary)
if self.orientation == "h":
return transpose(array((sx, sy)))
else:
return transpose(array((sy, sx)))
def map_data(self, screen_pt):
"""Maps a screen space point into the "index" space of the plot.
Implements the AbstractPlotRenderer interface.
"""
if self.orientation == "h":
screen_coord = screen_pt[0]
else:
screen_coord = screen_pt[1]
return self.index_mapper.map_data(screen_coord)
def map_index(
self,
screen_pt,
threshold=2.0,
outside_returns_none=True,
index_only=False,
):
"""Maps a screen space point to an index into the plot's index array(s).
Implements the AbstractPlotRenderer interface.
"""
data_pt = self.map_data(screen_pt)
if (
(data_pt < self.index_mapper.range.low)
or (data_pt > self.index_mapper.range.high)
) and outside_returns_none:
return None
index_data = self.index.get_data()
value_data = self.value.get_data()
if len(value_data) == 0 or len(index_data) == 0:
return None
try:
ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
except IndexError:
return None
x = index_data[ndx]
y = value_data[ndx]
result = self.map_screen(array([[x, y]]))
if result is None:
return None
sx, sy = result[0]
if index_only and ((screen_pt[0] - sx) < threshold):
return ndx
elif (screen_pt[0] - sx) ** 2 + (
screen_pt[1] - sy
) ** 2 < threshold * threshold:
return ndx
else:
return None
# ------------------------------------------------------------------------
# PlotComponent interface
# ------------------------------------------------------------------------
def _gather_points(self):
"""Collects data points that are within the range of the plot, and
caches them in **_cached_data_pts**.
"""
index, index_mask = self.index.get_data_mask()
value, value_mask = self.value.get_data_mask()
if not self.index or not self.value:
return
if len(index) == 0 or len(value) == 0 or len(index) != len(value):
logger.warning(
"Chaco: using empty dataset; index_len=%d, value_len=%d."
% (len(index), len(value))
)
self._cached_data_pts = array([])
self._cache_valid = True
return
# TODO: Until we code up a better handling of value-based culling that
# takes into account starting_value and dataspace bar widths, just use
# the index culling for now.
# value_range_mask = self.value_mapper.range.mask_data(value)
# nan_mask = invert(isnan(index_mask)) & invert(isnan(value_mask))
# point_mask = index_mask & value_mask & nan_mask & \
# index_range_mask & value_range_mask
index_range_mask = self.index_mapper.range.mask_data(index)
nan_mask = invert(isnan(index_mask))
point_mask = index_mask & nan_mask & index_range_mask
if self.starting_value is None:
starting_values = zeros(len(index))
else:
starting_values = self.starting_value.get_data()
if self.bar_width_type == "data":
half_width = self.bar_width / 2.0
points = column_stack(
(
index - half_width,
index + half_width,
starting_values,
value,
)
)
else:
points = column_stack((index, starting_values, value))
self._cached_data_pts = compress(point_mask, points, axis=0)
self._cache_valid = True
def _draw_plot(self, gc, view_bounds=None, mode="normal"):
"""Draws the 'plot' layer."""
if not self._cache_valid:
self._gather_points()
data = self._cached_data_pts
if data.size == 0:
# Nothing to draw.
return
with gc:
gc.clip_to_rect(self.x, self.y, self.width, self.height)
gc.set_antialias(self.antialias)
gc.set_stroke_color(self.effective_line_color)
gc.set_fill_color(self.effective_fill_color)
gc.set_line_width(self.line_width)
if self.bar_width_type == "data":
# map the bar start and stop locations into screen space
lower_left_pts = self.map_screen(data[:, (0, 2)])
upper_right_pts = self.map_screen(data[:, (1, 3)])
else:
half_width = self.bar_width / 2.0
# map the bar centers into screen space and then compute the bar
# start and end positions
lower_left_pts = self.map_screen(data[:, (0, 1)])
upper_right_pts = self.map_screen(data[:, (0, 2)])
lower_left_pts[:, 0] -= half_width
upper_right_pts[:, 0] += half_width
bounds = upper_right_pts - lower_left_pts
gc.rects(column_stack((lower_left_pts, bounds)))
gc.draw_path()
def _draw_default_axes(self, gc):
if not self.origin_axis_visible:
return
with gc:
gc.set_stroke_color(self.origin_axis_color_)
gc.set_line_width(self.origin_axis_width)
gc.set_line_dash(None)
for range in (self.index_mapper.range, self.value_mapper.range):
if (range.low < 0) and (range.high > 0):
if range == self.index_mapper.range:
dual = self.value_mapper.range
data_pts = array([[0.0, dual.low], [0.0, dual.high]])
else:
dual = self.index_mapper.range
data_pts = array([[dual.low, 0.0], [dual.high, 0.0]])
start, end = self.map_screen(data_pts)
gc.move_to(int(start[0]) + 0.5, int(start[1]) + 0.5)
gc.line_to(int(end[0]) + 0.5, int(end[1]) + 0.5)
gc.stroke_path()
def _render_icon(self, gc, x, y, width, height):
with gc:
gc.set_fill_color(self.effective_fill_color)
gc.set_stroke_color(self.effective_line_color)
gc.rect(x + width / 4, y + height / 4, width / 2, height / 2)
gc.draw_path(FILL_STROKE)
def _post_load(self):
super()._post_load()
# ------------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------------
def _get_index_range(self):
return self.index_mapper.range
def _set_index_range(self, val):
self.index_mapper.range = val
def _get_value_range(self):
return self.value_mapper.range
def _set_value_range(self, val):
self.value_mapper.range = val
def _get_x_mapper(self):
if self.orientation == "h":
return self.index_mapper
else:
return self.value_mapper
def _get_y_mapper(self):
if self.orientation == "h":
return self.value_mapper
else:
return self.index_mapper
def _get_x_direction(self):
if self.orientation == "h":
return self.index_direction
else:
return self.value_direction
def _get_y_direction(self):
if self.orientation == "h":
return self.value_direction
else:
return self.index_direction
# ------------------------------------------------------------------------
# Event handlers - these are mostly copied from BaseXYPlot
# ------------------------------------------------------------------------
def _update_mappers(self):
"""Updates the index and value mappers. Called by trait change handlers
for various traits.
"""
x_mapper = self.index_mapper
y_mapper = self.value_mapper
x_dir = self.index_direction
y_dir = self.value_direction
if self.orientation == "v":
x_mapper, y_mapper = y_mapper, x_mapper
x_dir, y_dir = y_dir, x_dir
x = self.x
x2 = self.x2
y = self.y
y2 = self.y2
if x_mapper is not None:
if x_dir == "normal":
x_mapper.low_pos = x
x_mapper.high_pos = x2
else:
x_mapper.low_pos = x2
x_mapper.high_pos = x
if y_mapper is not None:
if y_dir == "normal":
y_mapper.low_pos = y
y_mapper.high_pos = y2
else:
y_mapper.low_pos = y2
y_mapper.high_pos = y
self.invalidate_draw()
self._cache_valid = False
def _bounds_changed(self, old, new):
super()._bounds_changed(old, new)
self._update_mappers()
def _bounds_items_changed(self, event):
super()._bounds_items_changed(event)
self._update_mappers()
def _orientation_changed(self):
self._update_mappers()
def _index_changed(self, old, new):
if old is not None:
old.observe(self._either_data_updated, "data_changed", remove=True)
if new is not None:
new.observe(self._either_data_updated, "data_changed")
self._either_data_updated()
def _index_direction_changed(self):
m = self.index_mapper
m.low_pos, m.high_pos = m.high_pos, m.low_pos
self.invalidate_draw()
def _value_direction_changed(self):
m = self.value_mapper
m.low_pos, m.high_pos = m.high_pos, m.low_pos
self.invalidate_draw()
def _either_data_updated(self, event=None):
self.invalidate_draw()
self._cache_valid = False
self.request_redraw()
def _value_changed(self, old, new):
if old is not None:
old.observe(self._either_data_updated, "data_changed", remove=True)
if new is not None:
new.observe(self._either_data_updated, "data_changed")
self._either_data_updated()
def _index_mapper_changed(self, old, new):
return self._either_mapper_changed(old, new)
def _value_mapper_changed(self, old, new):
return self._either_mapper_changed(old, new)
def _either_mapper_changed(self, old, new):
if old is not None:
old.observe(self._mapper_updated_handler, "updated", remove=True)
if new is not None:
new.observe(self._mapper_updated_handler, "updated")
self.invalidate_draw()
def _mapper_updated_handler(self, event):
self._cache_valid = False
self.invalidate_draw()
self.request_redraw()
def _bar_width_changed(self):
self._cache_valid = False
self.invalidate_draw()
self.request_redraw()
def _bar_width_type_changed(self):
self._cache_valid = False
self.invalidate_draw()
self.request_redraw()
# ------------------------------------------------------------------------
# Property getters
# ------------------------------------------------------------------------
@cached_property
def _get_effective_line_color(self):
if len(self.line_color_) == 4:
line_alpha = self.line_color_[-1]
else:
line_alpha = 1.0
c = self.line_color_[:3] + (line_alpha * self.alpha,)
return c
@cached_property
def _get_effective_fill_color(self):
if len(self.fill_color_) == 4:
fill_alpha = self.fill_color_[-1]
else:
fill_alpha = 1.0
c = self.fill_color_[:3] + (fill_alpha * self.alpha,)
return c