Repository URL to install this package:
Version:
4.8.0 ▾
|
"""
Contains the logic behind creating and configuring new plots
from a set of user-supplied arguments.
"""
# Standard library imports
import io
import re
import six
# Major library imports
from numpy import all, array, arange, asarray, reshape, shape, transpose
# Chaco imports
from chaco.plot_factory import (create_line_plot, create_scatter_plot)
from chaco.array_data_source import ArrayDataSource
from chaco.image_data import ImageData
from chaco.tools.highlight_tool import HighlightTool
# Local relative imports
from .chaco_shell_error import ChacoShellError
# Normally I don't define an __all__, but this lets us distinguish
# the top level plot-producing functions from the various helper
# functions.
__all__ = ["do_plot", "do_imshow", "do_pcolor", "do_contour", "do_plotv",
"SizeMismatch", ]
#-----------------------------------------------------------------------------
# Exceptions
#-----------------------------------------------------------------------------
class SizeMismatch(ChacoShellError):
pass
#-----------------------------------------------------------------------------
# Utility functions
#-----------------------------------------------------------------------------
def is1D (a):
s = shape(a)
return ((len(s) == 1) or (s[0] == 1) or (s[1] == 1))
def is2D (a):
return (len(shape(a)) == 2)
def row ( a ):
return reshape( asarray( a ), [1,-1] )
def col ( a ):
return reshape( asarray( a ), [-1,1] )
#-----------------------------------------------------------------------------
# Plot commands for chaco-style plotv()
#-----------------------------------------------------------------------------
def do_plotv(session, *args, **kw):
""" Creates a list of plots from the data in ``*args`` and options in
``**kw``, according to the docstring on commands.plot().
"""
sort = kw.get("sort", "none")
sources_list = make_data_sources(session, sort, *args)
plot_type = kw.get("type", "line")
if plot_type == "scatter":
plots = [create_scatter_plot(sources) for sources in sources_list]
elif plot_type == "line":
plots = [create_line_plot(sources) for sources in sources_list]
else:
raise ChacoShellError("Unknown plot type '%s'." % plot_type)
for plot in plots:
plot.orientation = kw.get("orientation", "h")
return plots
def make_data_sources(session, index_sort="none", *args):
""" Given a list of arguments, returns a list of (index, value) datasources
to create plots from.
"""
# Make sure everything is a numpy array
data = []
for arg in args:
if isinstance(arg, list) or isinstance(arg, tuple):
data.append(array(arg))
else:
data.append(arg)
if len(data) == 0:
raise ChacoShellError("Insufficient data for plot.")
# 1D array(s)
if len(data[0].shape) == 1:
if len(data) == 1:
# Only a single array was provided
index_ds = ArrayDataSource(arange(len(data[0])), sort_order="ascending")
value_ds = ArrayDataSource(data[0], sort_order="none")
return [(index_ds, value_ds)]
else:
# multiple arrays were provided
index_ds = ArrayDataSource(data[0], sort_order=index_sort)
return [(index_ds, ArrayDataSource(v, sort_order="none")) for v in data[1:]]
# 2D arrays
elif len(data[0].shape) == 2:
sources = []
# Loop over all the 2D arrays
for ary in data:
if ary.shape[0] > ary.shape[1]:
index_ary = ary[:, 0]
value_arrays = ary[:, 1:]
else:
index_ary = ary[0]
value_arrays = transpose(ary[1:])
index_ds = ArrayDataSource(index_ary, sort_order=index_sort)
sources.extend([(index_ds, ArrayDataSource(v, sort_order="none")) for v in value_arrays])
return sources
# Not a two-dimensional array, error.
else:
raise ChacoShellError(
"Unable to create plot data sources from array of shape " +
str(data[1].shape) + ".")
#-----------------------------------------------------------------------------
# Plot commands for matlab-compatible plot() function
#-----------------------------------------------------------------------------
# Regular expressions for parsing the format string
color_re = re.compile('[ymcrgbwk]')
color_trans = {
'y': 'yellow',
'm': 'magenta',
'c': 'cyan',
'r': 'red',
'g': 'green',
'b': 'blue',
'w': 'white',
'k': 'black'
}
# This one isn't quite right:
marker_re = re.compile('[ox+s^v]|(?:[^-])[.]')
marker_trans = {
'.': 'dot',
'o': 'circle',
'x': 'cross',
'+': 'plus',
's': 'square',
'^': 'triangle',
'v': 'inverted_triangle'
}
line_re = re.compile('--|-\.|[-:]')
line_trans = {
'-': 'solid',
':': 'dot',
'-.': 'dot dash',
'--': 'dash'
}
def _process_format(format):
"""
Converts a format string into a (color, line, marker, marker_color) tuple.
"""
if format == '':
return ('black', 'solid', None, None)
color, line, marker, marker_color = 'black', None, None, None
m = color_re.findall(format)
if len(m) > 0:
color = marker_color = color_trans[m[0]]
if len(m) > 1:
marker_color = color_trans[m[1]]
m = marker_re.findall(format)
# The -1 takes care of 'r.', etc:
if len(m) > 0:
marker = marker_trans[m[0][-1]]
m = line_re.findall(format)
if len(m):
line = line_trans[m[0]]
return (color, line, marker, marker_color)
def _process_group(group, plot_data=None):
""" Returns a (x_1D, y_1D, format_str) tuple from an input tuple
of 1 to 3 elements: (x,y,format_str).
A PlotData object can be optionally provided to disambiguate the cases
when exactly two strings are passed in. The two strings could be the
names of the x and y datasources, or they could be the name of the y
datasource and a format string. By checking the second string against
the plot_data's list of datasources, the method can determine what it is meant
to be.
"""
# Interpret and split the 'group' tuple into x, y, and plotinfo
plotinfo = ''
if len(group) == 1:
y = group[0]
y_data = plot_data.get_data(y)
x = plot_data.set_data("", arange(len(y_data)), generate_name=True)
elif len(group) == 2:
# There are two possibilities here; a single y was provided along
# with a format string, or an x and y were provided. If PlotData
# was provided, use that to disambiguate; otherwise, assume that the
# second string is a format string.
if isinstance(group[1], six.string_types):
if plot_data and group[1] in plot_data.list_data():
x = group[0]
y = group[1]
else:
plotinfo = group[1]
y = group[0]
y_data = plot_data.get_data(y)
x = plot_data.set_data("", arange(len(y_data)), generate_name=True)
else:
x, y = group
elif len(group) == 3:
x, y, plotinfo = group
else:
raise ChacoShellError("Found too many elements in group while" \
" constructing plot.")
return x, y, plotinfo
def _check_sort_order(data):
diffs = data[1:] - data[:-1]
if all(diffs >= 0):
return "ascending"
elif all(diffs <= 0):
return "descending"
else:
return "none"
def do_plot(plotdata, active_plot, *data_and_formats, **kwtraits):
""" Takes a list of data (arrays or names) and format string arguments
and creates new plots on the active_plot. Returns a list of plot names
on the active plot.
"""
# The list of data and formats is broken up by format strings,
# so we break it up by arguments that are strings.
cur_group = []
groups = []
valid_names = plotdata.list_data()
for arg in data_and_formats:
if not isinstance(arg, six.string_types):
# an array was passed in
cur_group.append(plotdata.set_data("", arg, generate_name=True))
elif arg in valid_names:
# the name of an existing plotdata item was passed in
cur_group.append(arg)
else:
# String that is not in plotdata is interpreted as a format
# string, thereby terminating this group
cur_group.append(arg)
groups.append(cur_group)
cur_group = []
if len(cur_group) > 0:
groups.append(cur_group)
# Process the list of groups and create a list of plots;
# broadcast the keyword traits to all of them.
plots = []
for group in groups:
x, y, format_str = _process_group(group, plot_data = plotdata)
linecolor, line, marker, markercolor = _process_format(format_str)
plot_type = []
format = kwtraits.copy()
if line is not None:
plot_type.append("line")
format["line_style"] = line
format["color"] = linecolor
if marker is not None:
plot_type.append("scatter")
format["marker"] = marker
format["color"] = markercolor
x_sort_order = _check_sort_order(plotdata.get_data(x))
plots.extend(active_plot.plot((x,y), type=",".join(plot_type), **format))
# Set the sort order
x_ds = active_plot.datasources[x]
if isinstance(x_ds, ArrayDataSource):
x_ds.sort_order = x_sort_order
# Check to see if the active_plot has a highlighter tool already; if not,
# then add it.
for tool in active_plot.tools:
if isinstance(tool, HighlightTool):
break
else:
active_plot.tools.append(HighlightTool(active_plot))
return plots
def do_imread(*data, **kwargs):
""" Returns image file as array. """
# Check to see if the data given is either a file path or a file object
if isinstance(data[0], six.string_types) or isinstance(data[0], io.IOBase):
return ImageData.fromfile(data[0])
else:
raise ValueError("do_imread takes a string filename")
def do_imshow(plotdata, active_plot, *data, **kwargs):
""" Creates an image plot on the active plot, given either
a filename or data.
"""
if len(data) != 1:
raise ValueError("do_imshow takes one data source")
x = None
y = None
try:
z = _get_or_create_plot_data(data[0], plotdata)
except ValueError:
# z is the name of the file
# create plot data
image = do_imread(data[0], *data, **kwargs)
z = plotdata.set_data("", image, generate_name=True)
plot_list = [active_plot.img_plot(z, xbounds=x, ybounds=y, **kwargs)]
return plot_list
def do_pcolor(plotdata, colormap, active_plot, *data, **kwargs ):
""" Creates a pseudocolor image plot on the active plot, given a 2-D
scalar data and a colormap.
"""
# if we get just one data source, it is assumed to be the scalar field
if len(data) == 1:
x = None
y = None
z = _get_or_create_plot_data(data[0], plotdata)
# three data sources means we got x-y grid data of some sort, too
elif len(data) == 3:
x = _get_or_create_plot_data(data[0], plotdata)
y = _get_or_create_plot_data(data[1], plotdata)
z = _get_or_create_plot_data(data[2], plotdata)
else:
raise ValueError("do_pcolor takes one or three data sources")
plot_list = [active_plot.img_plot(z, xbounds=x, ybounds=y,
colormap=colormap, **kwargs)]
return plot_list
def do_contour(plotdata, colormap, active_plot, type, *data, **kwargs ):
""" Creates a contour plot on the active plot, given a 2-D
scalar data and a colormap.
"""
# if we get just one data source, it is assumed to be the scalar field
if len(data) == 1:
x = None
y = None
z = _get_or_create_plot_data(data[0], plotdata)
# three data sources means we got x-y grid data of some sort, too
elif len(data) == 3:
x = _get_or_create_plot_data(data[0], plotdata)
y = _get_or_create_plot_data(data[1], plotdata)
z = _get_or_create_plot_data(data[2], plotdata)
else:
raise ValueError("do_contour takes one or three data sources")
# we have to do slightly different calls here because of the different
# handling of colormaps
if type is 'poly':
plot_list = [active_plot.contour_plot(z, type, xbounds=x, ybounds=y,
poly_cmap=colormap,
**kwargs)]
else:
plot_list = [active_plot.contour_plot(z, type, xbounds=x, ybounds=y,
colors=colormap,
**kwargs)]
return plot_list
def _get_or_create_plot_data(data, plotdata):
"""Create a new name for `data` if necessary, or check it is a valid name.
"""
valid_names = plotdata.list_data()
if not isinstance(data, six.string_types):
name = plotdata.set_data("", data, generate_name=True)
else:
if data not in valid_names:
msg = '{} is not an existing name for plot data'
raise ValueError(msg.format(data))
name = data
return name
# EOF