Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
chaco / plots / quiverplot.py
Size: Mime:
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

from numpy import array, compress, matrix, newaxis, sqrt, zeros

# Enthought library imports
from enable.api import ColorTrait
from traits.api import Array, Enum, Float, Instance, Int

# Chaco relative imports
from chaco.abstract_data_source import AbstractDataSource
from chaco.plots.scatterplot import ScatterPlot


class QuiverPlot(ScatterPlot):

    #: Determines how to interpret the data in the **vectors** data source.
    #:   "vector": each tuple is a (dx, dy)
    #:   "radial": each tuple is an (r, theta)
    data_type = Enum("vector", "radial")  # TODO: implement "radial"

    #: A datasource that returns an Nx2 array array indicating directions
    #: of the vectors.  The interpretation of this array is dependent on
    #: the setting of the **data_type** attribute.
    #:
    #: Usually this will be a MultiArrayDataSource.
    vectors = Instance(AbstractDataSource)

    # ------------------------------------------------------------------------
    # Visual attributes of the vector
    # ------------------------------------------------------------------------

    #: The color of the lines
    line_color = ColorTrait("black")

    #: The width of the lines
    line_width = Float(1.0)

    #: The length, in pixels, of the arrowhead
    arrow_size = Int(5)

    # ------------------------------------------------------------------------
    # Private traits
    # ------------------------------------------------------------------------

    _cached_vector_data = Array
    _selected_vector_data = Array

    def _gather_points_old(self):
        # In addition to the standard scatterplot _gather_points, we need
        # to also grab the vectors that fall inside the view range
        super()._gather_points_old()

        if not self.index or not self.value:
            return

        if len(self._cached_point_mask) == 0:
            self._cached_vector_data = []
            return

        vectors = self.vectors.get_data()
        self._cached_vector_data = compress(
            self._cached_point_mask, vectors, axis=0
        )

        if self._cached_selected_pts is not None:
            indices = self._cached_selection_point_mask
            self._selected_vector_data = compress(indices, vectors, axis=0)
        else:
            self._selected_vector_data = None

    def _render(self, gc, points, icon_mode=False):
        if len(points) < 1:
            return

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)

            gc.set_stroke_color(self.line_color_)
            gc.set_line_width(self.line_width)

            # Draw the body of the arrow
            starts = points
            ends = points + self._cached_vector_data
            gc.begin_path()
            gc.line_set(starts, ends)
            gc.stroke_path()

            if self.arrow_size > 0:
                vec = self._cached_vector_data
                unit_vec = (
                    vec / sqrt(vec[:, 0] ** 2 + vec[:, 1] ** 2)[:, newaxis]
                )
                a = 0.707106781  # sqrt(2)/2

                # Draw the left arrowhead (for an arrow pointing straight up)
                arrow_ends = (
                    ends
                    - array(unit_vec * matrix([[a, a], [-a, a]]))
                    * self.arrow_size
                )
                gc.begin_path()
                gc.line_set(ends, arrow_ends)
                gc.stroke_path()

                # Draw the left arrowhead (for an arrow pointing straight up)
                arrow_ends = (
                    ends
                    - array(unit_vec * matrix([[a, -a], [a, a]]))
                    * self.arrow_size
                )
                gc.begin_path()
                gc.line_set(ends, arrow_ends)
                gc.stroke_path()