Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
chaco / plots / multi_line_plot.py
Size: Mime:
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

""" Defines the MultiLinePlot class.
"""


# Standard library imports
import warnings
from math import ceil, floor

# Major library imports
import numpy as np
from numpy import argsort, array, invert, isnan, take, transpose

# Enthought library imports
from enable.api import black_color_trait, ColorTrait, LineStyle
from traits.api import (
    Float,
    List,
    Str,
    Trait,
    Bool,
    Callable,
    Property,
    cached_property,
    Instance,
    Array,
)
from traitsui.api import Item, View, ScrubberEditor, HGroup

from chaco.array_data_source import ArrayDataSource
from chaco.base import arg_find_runs, bin_search
from chaco.base_xy_plot import BaseXYPlot


class MultiLinePlot(BaseXYPlot):
    """A plot consisting of multiple lines.

    The data to be plotted must come from a two-dimensional array with shape M by N
    stored in a MultiArrayDataSource object.  M is the number of lines to be plotted,
    and N is the number of points in each line.

    Parameters
    ----------
    index : instance of an ArrayDataSource
        These are the 'x' or abscissa coordinates.

    yindex : instance of ArrayDataSource
        These are the 'y' coordinates.  Each trace is drawn relative to its
        own entry in `yindex` (its *base* y coordinate).

    value : instance of a MultiArrayDataSource
        Note that the `scale`, `offset` and `normalized_amplitude` attributes of the
        MultiLinePlot control the projection of the traces into the (x,y)
        plot.  In simplest case, `scale=1` and `offset=0`, and `normalized_amplitude`
        controls the scaling of the traces relative to their base y value.

    global_min, global_max : float
        The minimum and maximum values of the data in `value`.  For large
        arrays, computing these could take excessive time, so they must be
        provided when an instance is created.

    use_global_bounds : bool
        If True, the amplitude scale is derived from `global_min` and
        `global_max`; otherwise it is derived from the largest absolute value
        found in the `value` data.
        Default: True

    normalized_amplitude : Float
        The amplitude value exposed to the user.  It is multiplied by
        `amplitude_scale` to obtain the `amplitude` that is applied to the data.
        Default: -0.5

    color : ColorTrait
        The color of the lines.  Ignored when `color_func` is not None.

    color_func : Callable or None
        If not None, this Callable overrides `color`.  The argument to `color_func`
        will be the integer index of the trace to be rendered.  `color_func` must
        return an RGBA 4-tuple.
        Default: None

    orientation : str
        Must be 'v' or 'h' (for 'vertical' or 'horizontal', respectively).  This is
        the orientation of the index axis (i.e. the 'x' axis).
        Default: 'h'

    fast_clip : bool
        If True, traces whose *base* 'y' coordinate is outside the value axis range
        are not plotted, even if some of the data in the curve extends into the plot
        region.
        Default: False

    line_width : float
        Width of the plotted lines.

    line_style : LineStyle
        The style of the trace lines in the plot.

    The following are from the original LinePlot code, and are untested:

    selected_color
    selected_line_style

    """

    #: M and N appearing in the comments are as defined in the docstring.

    #: The base 'y' coordinate of each of the M traces.
    yindex = Instance(ArrayDataSource)

    # amplitude = Float(0.0)

    #: `scale` and `offset` provide a more general transformation, but are currently
    #: untested.  The projected data is
    #: ``scale * (amplitude * data + ycoordinate) + offset``.
    scale = Float(1.0)
    offset = Float(0.0)

    #: If True, traces whose base 'y' coordinate lies outside the value range
    #: are dropped before any points are gathered.
    fast_clip = Bool(False)

    #: The color of the lines.
    color = black_color_trait(requires_redraw=True)

    #: A function that returns the color of lines.  Overrides `color` if not None.
    color_func = Trait(None, None, Callable)

    #: The color to use to highlight the line when selected.
    selected_color = ColorTrait("lightyellow")

    #: The style of the selected line.
    selected_line_style = LineStyle("solid")

    #: The name of the key in self.metadata that holds the selection mask
    metadata_name = Str("selections")

    #: The thickness of the line.
    line_width = Float(1.0, requires_redraw=True)

    #: The line dash style.
    line_style = LineStyle(requires_redraw=True)

    #: If True, `amplitude_scale` is computed from `global_min` and
    #: `global_max`; otherwise it is computed from the data held by `value`.
    use_global_bounds = Bool(True)

    #: Minimum value in the `value` data source.  This must be provided
    #: in the call to the constructor.
    global_min = Float

    #: Maximum value in the `value` data source.  This must be provided
    #: in the call to the constructor.
    global_max = Float

    #: Normalized amplitude is the value exposed to the user.
    normalized_amplitude = Float(-0.5)

    #: Factor that converts `normalized_amplitude` into `amplitude`.  It is
    #: derived from the spacing of the first two `yindex` coordinates and from
    #: the largest absolute data value (see `_get_amplitude_scale`).
    amplitude_scale = Property(
        Float,
        observe=[
            "global_min",
            "global_max",
            "value._data",
            "use_global_bounds",
            "yindex",
        ],
    )

    #: The multiplier actually applied to the data:
    #: ``normalized_amplitude * amplitude_scale``.
    amplitude = Property(
        Float, observe=["normalized_amplitude", "amplitude_scale"]
    )

    # ------------------------------------------------------------------------
    # Private traits
    # ------------------------------------------------------------------------

    # The projected 2D numpy array.
    _trace_data = Property(
        Array,
        observe=[
            "index",
            "index.data_changed",
            "value",
            "value.data_changed",
            "yindex",
            "yindex.data_changed",
            "amplitude",
            "scale",
            "offset",
        ],
    )

    # Cached list of non-NaN arrays of (x,y) data-space points; regardless of
    # self.orientation, this is always stored as (index_pt, value_pt).  This is
    # different from the default BaseXYPlot definition.
    # The outer list has one entry per trace; each entry is itself a list of
    # 2-column arrays, one per contiguous run of drawable points.
    _cached_data_pts = List

    # Cached list of non-NaN arrays of (x,y) screen-space points.
    _cached_screen_pts = List

    # ------------------------------------------------------------------------
    #
    # ------------------------------------------------------------------------

    def trait_view(self, obj):
        """Build the minimalist TraitsUI View used to customize the plot.

        Only `use_global_bounds`, `normalized_amplitude` (edited with a
        scrubber) and the trace `color` are exposed.  The `obj` argument is
        not used.
        """
        amplitude_editor = ScrubberEditor(
            increment=0.2,
            hover_color=0xFFFFFF,
            active_color=0xA0CD9E,
            border_color=0x0000FF,
        )
        amplitude_row = HGroup(
            Item("use_global_bounds"),
            Item("normalized_amplitude", editor=amplitude_editor),
        )
        color_item = Item("color", label="Trace color", style="simple")

        return View(
            amplitude_row,
            color_item,
            width=480,
            title="Trace Plot Line Attributes",
            buttons=["OK", "Cancel"],
        )

    # ------------------------------------------------------------------------
    #
    # ------------------------------------------------------------------------

    # See base_xy_plot.py for these:
    ## def hittest(self, screen_pt, threshold=7.0):
    ## def interpolate(self, index_value):

    def get_screen_points(self):
        """Return the cached data points mapped into screen space.

        The result mirrors the nesting of `_cached_data_pts`: one list per
        trace, each holding the `map_screen` result for every point array of
        that trace.
        """
        # Make sure the data-space cache is current before mapping it.
        self._gather_points()

        screen_lines = []
        for segments in self._cached_data_pts:
            mapped = []
            for segment in segments:
                mapped.append(self.map_screen(segment))
            screen_lines.append(mapped)
        return screen_lines

    # ------------------------------------------------------------------------
    # Private methods
    # ------------------------------------------------------------------------

    @cached_property
    def _get_amplitude_scale(self):
        """Return the factor that turns `normalized_amplitude` into `amplitude`.

        If the amplitude is set to this value, the largest trace deviation from
        its base y coordinate will be equal to half the y coordinate spacing
        (``0.5 * dy``).

        The spacing ``dy`` is taken from the first two `yindex` coordinates; it
        falls back to 1.0 when there are fewer than two coordinates or when the
        first two coincide.  The largest absolute data value comes from
        `global_min`/`global_max` when `use_global_bounds` is True, otherwise
        from the data held by `value`.  When that maximum is zero (or there is
        no usable data) the result is simply ``0.5 * dy``.
        """
        # Note: Like the rest of the current code, this ignores the `scale` attribute.

        if self.yindex is not None:
            coordinates = self.yindex.get_data()
        else:
            coordinates = []

        # Default coordinate spacing, used when there are fewer than two
        # coordinates or the first two are identical.
        dy = 1.0
        if len(coordinates) > 1:
            spacing = coordinates[1] - coordinates[0]
            if spacing != 0:
                dy = spacing

        if self.use_global_bounds:
            max_abs = max(abs(self.global_min), abs(self.global_max))
        else:
            magnitudes = np.atleast_1d(np.abs(self.value._data))
            # NaNs mark gaps in the traces; letting them through would turn
            # the scale into NaN and blank out every trace.
            valid = magnitudes[~np.isnan(magnitudes)]
            # An empty (or all-NaN) array has no maximum; treat it like
            # all-zero data instead of letting the reduction raise.
            max_abs = valid.max() if valid.size > 0 else 0

        if max_abs == 0:
            return 0.5 * dy
        return 0.5 * dy / max_abs

    @cached_property
    def _get_amplitude(self):
        """Return the data multiplier: `normalized_amplitude` times `amplitude_scale`."""
        return self.normalized_amplitude * self.amplitude_scale

    @cached_property
    def _get__trace_data(self):
        """Project the raw values into plot coordinates.

        Each row of the `value` data is multiplied by `amplitude`, shifted by
        the matching `yindex` coordinate, and the result is then transformed
        by `scale` and `offset`.
        """
        raw_values = self.value._data
        # Turn the coordinates into a column so that each one broadcasts
        # along the corresponding row of the value array.
        base_column = self.yindex.get_data()[:, np.newaxis]
        shifted = self.amplitude * raw_values + base_column
        return self.scale * shifted + self.offset

    def _gather_points(self):
        """
        Collects the data points that are within the bounds of the plot and
        caches them.

        On return `_cached_data_pts` holds one list per trace; each of those
        lists contains 2-column arrays of (index, value) data-space points, one
        array per contiguous run of non-NaN points that intersects the index
        range.  Each run is widened by one point on either side so that lines
        are drawn up to the edge of the plot area.

        Nothing is done when the cache is already valid, or when `index`,
        `value` or `yindex` is missing.  An empty list is cached when there is
        no data, when the index and value lengths disagree, or when the data
        lies completely outside the visible region.
        """

        if self._cache_valid:
            return

        # `_trace_data` needs `yindex` as well, so treat a missing `yindex`
        # the same way as a missing index or value.
        if not self.index or not self.value or self.yindex is None:
            return

        index = self.index.get_data()
        varray = self._trace_data

        if self.fast_clip and varray.size > 0:
            coordinates = self.yindex.get_data()
            coord_min = float(coordinates[0])
            coord_span = coordinates[-1] - coord_min
            # With a zero span (e.g. a single trace) the value range cannot
            # be mapped onto row numbers, so no rows are clipped.
            if coord_span != 0:
                n_rows = varray.shape[0]
                slice_min = max(
                    0,
                    ceil(
                        (n_rows - 1)
                        * (self.value_range.low - coord_min)
                        / coord_span
                    ),
                )
                slice_max = min(
                    n_rows,
                    1 + floor(
                        (n_rows - 1)
                        * (self.value_range.high - coord_min)
                        / coord_span
                    ),
                )
                # Never let the upper bound drop below the lower one: a
                # negative stop would be read as an offset from the end of
                # the array and select the wrong rows.
                slice_max = max(slice_min, slice_max)
                varray = varray[slice_min:slice_max]
                # FIXME: The y coordinates must also be sliced to match varray.

        # Nothing to draw when there is no data or when the index and value
        # lengths disagree.  This must be checked before taking the min/max
        # below, which raise on an empty array (fast_clip can remove every
        # row).
        if (
            len(index) == 0
            or varray.size == 0
            or len(index) != varray.shape[1]
        ):
            self._cached_data_pts = []
            self._cache_valid = True
            return

        # Check to see if the data is completely outside the view region.
        # Check x coordinates.
        low, high = self.index.get_bounds()
        outside = low > self.index_range.high or high < self.index_range.low

        if not outside:
            # Check y coordinates. Use varray because it is based on the
            # yindex, but has been shifted up or down depending on the values.
            ylow, yhigh = varray.min(), varray.max()
            outside = (
                ylow > self.value_range.high or yhigh < self.value_range.low
            )

        if outside:
            self._cached_data_pts = []
            self._cache_valid = True
            return

        # Note: len(index) == varray.shape[1] is guaranteed at this point, so
        # no truncation of either array is required.

        # Split the index and value raw data into non-NaN chunks.
        # nan_mask is a boolean M by N array.
        nan_mask = invert(isnan(varray)) & invert(isnan(index))
        blocks_list = [
            [b for b in arg_find_runs(nm, "flat") if nm[b[0]] != 0]
            for nm in nan_mask
        ]

        line_points = []
        for k, blocks in enumerate(blocks_list):
            points = []
            for block in blocks:
                start, end = block
                block_index = index[start:end]
                block_value = varray[k, start:end]
                index_mask = self.index_mapper.range.mask_data(block_index)

                runs = [
                    r
                    for r in arg_find_runs(index_mask, "flat")
                    if index_mask[r[0]] != 0
                ]

                # Check to see if our data view region is between two points in the
                # index data.  If so, then we have to reverse map our current view
                # into the appropriate index and draw the bracketing points.
                if not runs:
                    data_pt = self.map_data(
                        (self.x_mapper.low_pos, self.y_mapper.low_pos)
                    )
                    if self.index.sort_order == "none":
                        indices = argsort(index)
                        sorted_index = take(index, indices)
                        sorted_value = take(varray[k], indices)
                        sort = 1
                    else:
                        sorted_index = index
                        sorted_value = varray[k]
                        if self.index.sort_order == "ascending":
                            sort = 1
                        else:
                            sort = -1
                    ndx = bin_search(sorted_index, data_pt, sort)
                    if ndx == -1:
                        # bin_search can return -1 if data_pt is outside the bounds
                        # of the source data
                        continue

                    z = transpose(
                        array(
                            (
                                sorted_index[ndx : ndx + 2],
                                sorted_value[ndx : ndx + 2],
                            )
                        )
                    )
                    points.append(z)

                else:
                    # Expand the width of every group of points so we draw the lines
                    # up to their next point, outside the plot area
                    data_end = len(index_mask)
                    for run in runs:
                        start, end = run
                        if start != 0:
                            start -= 1
                        if end != data_end:
                            end += 1

                        run_data = transpose(
                            array(
                                (
                                    block_index[start:end],
                                    block_value[start:end],
                                )
                            )
                        )
                        points.append(run_data)
            line_points.append(points)

        self._cached_data_pts = line_points
        self._cache_valid = True

    # See base_xy_plot.py for:
    ## def _downsample(self):
    ## def _downsample_vectorized(self):

    def _render(self, gc, line_points, selected_points=None):
        """Draw all traces, and optionally a selection highlight, into `gc`.

        Parameters
        ----------
        gc : graphics context
            The context to draw into.  Its state is saved and restored around
            the drawing calls.
        line_points : sequence
            One entry per trace; each entry is a collection of point arrays
            that is handed to `_render_normal`.
        selected_points : optional
            If not None, these points are drawn first, using the selection
            color and line style and a line 10 units wider than `line_width`.
        """
        if len(line_points) == 0:
            return

        with gc:
            gc.set_antialias(True)
            gc.clip_to_rect(self.x, self.y, self.width, self.height)

            render = self._render_normal

            if selected_points is not None:
                gc.set_stroke_color(self.selected_color_)
                gc.set_line_width(self.line_width + 10.0)
                gc.set_line_dash(self.selected_line_style_)
                render(gc, selected_points)

            # Existence of self.color_func overrides self.color.
            color_func = self.color_func

            # Note: the list is reversed for testing with _render_filled.
            for k, points in reversed(list(enumerate(line_points))):
                if color_func is not None:
                    color = color_func(k)
                else:
                    color = self.color_
                # Normalize to a tuple so that a list or array returned by
                # `color_func` is concatenated below rather than raising or
                # being added element-wise.
                color = tuple(color)
                # Apply the alpha
                alpha = color[-1] if len(color) == 4 else 1
                color = color[:3] + (alpha * self.alpha,)
                gc.set_stroke_color(color)
                gc.set_line_width(self.line_width)
                gc.set_line_dash(self.line_style_)
                render(gc, points)

            # Draw the default axes, if necessary
            self._draw_default_axes(gc)

    def _render_normal(self, gc, points):
        """Stroke every non-empty array in `points` as a separate path."""
        for segment in points:
            if len(segment) == 0:
                # Nothing to draw for an empty run of points.
                continue
            gc.begin_path()
            gc.lines(segment)
            gc.stroke_path()

    def _render_icon(self, gc, x, y, width, height):
        """Draw a horizontal line across the vertical middle of the given box.

        The line uses the plot's color, line width and dash style, with
        antialiasing turned off.
        """
        with gc:
            gc.set_stroke_color(self.color_)
            gc.set_line_width(self.line_width)
            gc.set_line_dash(self.line_style_)
            gc.set_antialias(0)
            mid_y = y + height / 2
            gc.move_to(x, mid_y)
            gc.line_to(x + width, mid_y)
            gc.stroke_path()

    def _amplitude_changed(self):
        """Flag the value data as changed and schedule a redraw.

        Named after the Traits ``_<name>_changed`` handler convention for the
        `amplitude` trait.
        """
        source = self.value
        source.data_changed = True
        self.invalidate_draw()
        self.request_redraw()