Repository URL to install this package:
Version:
6.0.0 ▾
|
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
"""
A plot that renders text values in two dimensions
"""
from numpy import array, column_stack, empty, isfinite
# Enthought library imports
from enable.api import black_color_trait
from kiva.trait_defs.kiva_font_trait import KivaFont
from traits.api import Bool, Enum, Float, Int, Instance, List, Tuple, observe
# local imports
from chaco.array_data_source import ArrayDataSource
from chaco.label import Label
from chaco.base_xy_plot import BaseXYPlot
class TextPlot(BaseXYPlot):
    """ A plot that positions textual labels in 2D.

    Each entry of ``text`` is drawn as a :class:`Label` at the screen
    position of the corresponding (index, value) point. Points whose index
    or value is non-finite, masked by its data source, or outside the
    corresponding mapper's range are not labelled.
    """

    #: text values corresponding to indices
    text = Instance(ArrayDataSource)

    #: The font of the tick labels.
    text_font = KivaFont("sans-serif 10", redraw=True)

    #: The color of the tick labels.
    text_color = black_color_trait(redraw=True)

    #: The rotation of the tick labels.
    text_rotate_angle = Float(0, redraw=True)

    #: The margin around the label.
    text_margin = Int(2, redraw=True)

    #: horizontal position of text relative to target point
    h_position = Enum("center", "left", "right", redraw=True)

    #: vertical position of text relative to target point
    v_position = Enum("center", "top", "bottom", redraw=True)

    #: offset of text relative to non-index direction in pixels
    text_offset = Tuple(Float, Float, redraw=True)

    # ------------------------------------------------------------------------
    # Private traits
    # ------------------------------------------------------------------------

    #: flag for whether the cache of Label instances is valid
    _label_cache_valid = Bool(False, transient=True)

    #: cache of Label instances for faster rendering
    _label_cache = List(transient=True)

    #: cache of bounding boxes of labels
    _label_box_cache = List(transient=True)

    # ------------------------------------------------------------------------
    # Private methods
    # ------------------------------------------------------------------------

    def _compute_labels(self, gc):
        """Generate the Label instances for the plot and cache their boxes.

        One Label is built per entry of ``self.text``. Its bounding box, as
        measured with *gc*, is stored at the same position in
        ``_label_box_cache``. When no text source is set, both caches are
        emptied instead of raising.
        """
        texts = [] if self.text is None else self.text.get_data()
        self._label_cache = [
            Label(
                text=text,
                font=self.text_font,
                color=self.text_color,
                rotate_angle=self.text_rotate_angle,
                margin=self.text_margin,
            )
            for text in texts
        ]
        # Entries [0] and [1] of each box are used as the label's width and
        # height by _label_offset.
        self._label_box_cache = [
            array(label.get_bounding_box(gc), float)
            for label in self._label_cache
        ]
        self._label_cache_valid = True

    def _gather_points(self):
        """Collect the data points that are within the range of the plot,
        and cache them.

        Stores the visible (index, value) pairs, one per row, in
        ``_cached_data_pts``. Stores the boolean mask that selects them from
        the full data in ``_cached_point_mask``. Then marks the cache valid.
        Does nothing if the cache is already valid or if either data source
        is missing. Empty or length-mismatched data yields empty caches.
        """
        if self._cache_valid:
            return
        if not self.index or not self.value:
            return

        index, index_mask = self.index.get_data_mask()
        value, value_mask = self.value.get_data_mask()

        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
            self._cached_data_pts = []
            self._cached_point_mask = []
            self._cache_valid = True
            return

        index_range_mask = self.index_mapper.range.mask_data(index)
        value_range_mask = self.value_mapper.range.mask_data(value)
        valid_mask = (
            isfinite(index) & index_mask & isfinite(value) & value_mask
        )
        point_mask = valid_mask & index_range_mask & value_range_mask

        if point_mask.all():
            # Nothing is filtered out; skip the fancy-indexing copies.
            points = column_stack([index, value])
        else:
            points = column_stack([index[point_mask], value[point_mask]])

        self._cached_data_pts = points
        self._cached_point_mask = point_mask
        self._cache_valid = True

    def _label_offset(self, box):
        """Return the (dx, dy) shift that places a label relative to its point.

        The shift follows ``h_position`` / ``v_position`` and includes
        ``text_offset``. *box* supplies the label's width (``box[0]``) and
        height (``box[1]``). The zero shift for "right" and "top" implies the
        label is drawn from its lower-left corner at the translated origin.
        "center" therefore shifts by half the box, and "left" / "bottom"
        shift by the full box.
        """
        width, height = box[0], box[1]
        offset_x, offset_y = self.text_offset
        h_shift = {"center": -width / 2, "right": 0.0, "left": -width}
        v_shift = {"center": -height / 2, "top": 0.0, "bottom": -height}
        return (
            h_shift[self.h_position] + offset_x,
            v_shift[self.v_position] + offset_y,
        )

    def _render(self, gc, pts):
        """Draw one label per visible point, clipped to the plot's bounds.

        Parameters
        ----------
        gc : GraphicsContext
            The graphics context to draw into.
        pts : array
            Screen-space positions of the visible points, one row per point.
            Rows are paired in order with the labels whose
            ``_cached_point_mask`` entry is True. This array is not modified.
        """
        if not self._label_cache_valid:
            self._compute_labels(gc)

        mask = self._cached_point_mask
        labels = [
            label for label, keep in zip(self._label_cache, mask) if keep
        ]
        boxes = [
            box for box, keep in zip(self._label_box_cache, mask) if keep
        ]

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            for pt, label, box in zip(pts, labels, boxes):
                dx, dy = self._label_offset(box)
                with gc:
                    # Translate to a derived position instead of doing
                    # ``pt += offset``, which would write into the caller's
                    # ``pts`` array.
                    gc.translate_ctm(pt[0] + dx, pt[1] + dy)
                    label.draw(gc)

    # ------------------------------------------------------------------------
    # Trait events
    # ------------------------------------------------------------------------

    @observe("index.data_changed")
    def _invalidate(self, event):
        """Invalidate the point, screen and label caches on index changes."""
        self._cache_valid = False
        self._screen_cache_valid = False
        self._label_cache_valid = False

    @observe("text.data_changed,value.data_changed,+redraw")
    def _invalidate_labels(self, event):
        """Force labels to be rebuilt on the next render.

        This fires when the ``text`` source is replaced or its data changes,
        when value data changes, and when any ``redraw=True`` trait
        (font, color, rotation, margin, positioning) changes.
        """
        self._label_cache_valid = False