Repository URL to install this package:
|
Version:
6.1.1 ▾
|
#------------------------------------------------------------------------------
#
# Copyright (c) 2005, Enthought, Inc.
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in enthought/LICENSE.txt and may be redistributed only
# under the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
#
# Author: David C. Morrill
# Date: 07/01/2005
#
#------------------------------------------------------------------------------
""" Defines the table grid model used by the table editor based on the PyFace
grid control.
"""
#-------------------------------------------------------------------------
# Imports:
#-------------------------------------------------------------------------
from __future__ import absolute_import
import logging
import wx
import wx.grid as wxg
from traits.api \
import HasPrivateTraits, Any, Str, Instance, Event, Bool, \
on_trait_change
from traitsui.api \
import View, Item, Editor
from traitsui.editors.table_editor \
import ReversedList
from traitsui.table_filter \
import TableFilter
from traitsui.ui_traits \
import SequenceTypes
from pyface.ui.wx.grid.api \
import GridModel, GridSortEvent
from pyface.ui.wx.grid.trait_grid_cell_adapter \
import TraitGridCellAdapter
from pyface.timer.api \
import do_later
logger = logging.getLogger(__name__)
#-------------------------------------------------------------------------
# 'TraitGridSelection' class:
#-------------------------------------------------------------------------
class TraitGridSelection(HasPrivateTraits):
""" Structure for holding specification information.
"""
# The selected object
obj = Any
# The specific trait selected on the object
name = Str
#-------------------------------------------------------------------------
# 'TableModel' class:
#-------------------------------------------------------------------------
class TableModel(GridModel):
""" Model for table data.
"""
#-------------------------------------------------------------------------
# Trait definitions:
#-------------------------------------------------------------------------
# The editor that created this model
editor = Instance(Editor)
# The current filter
filter = Instance(TableFilter, allow_none=True)
# Current filter summary message
filter_summary = Str('All items')
# Display the table items in reverse order?
reverse = Bool(False)
# Event fired when the table has been sorted
sorted = Event
# The current 'auto_add' row
auto_add_row = Any
#-------------------------------------------------------------------------
# 'object' interface:
#-------------------------------------------------------------------------
#-------------------------------------------------------------------------
# Initializes the object:
#-------------------------------------------------------------------------
def __init__(self, **traits):
""" Initializes the object.
"""
super(TableModel, self).__init__(**traits)
# Attach trait handlers to the list object:
editor = self.editor
object = editor.context_object
name = ' ' + editor.extended_name
# Set up listeners for any of the model data changing:
object.on_trait_change(self._on_data_changed, name, dispatch='ui')
object.on_trait_change(self.fire_content_changed, name + '.-',
dispatch='ui')
# Set up listeners for any column definitions changing:
editor.on_trait_change(self.update_columns, 'columns',
dispatch='ui')
editor.on_trait_change(self.update_columns, 'columns_items',
dispatch='ui')
# Initialize the current filter from the editor's default filter:
self.filter = editor.filter
# If we are using 'auto_add' mode, create the first 'auto_add' row:
if editor.auto_add:
self.auto_add_row = row = editor.create_new_row()
if row is not None:
row.on_trait_change(self.on_auto_add_row, dispatch='ui')
#-- TableModel Interface -------------------------------------------------
#-------------------------------------------------------------------------
# Disposes of the model when it is no longer needed:
#-------------------------------------------------------------------------
def dispose(self):
""" Disposes of the model when it is no longer needed.
"""
editor = self.editor
object = editor.context_object
name = ' ' + editor.extended_name
# Remove listeners for any of the model data changing:
object.on_trait_change(self._on_data_changed, name, remove=True)
object.on_trait_change(self.fire_content_changed, name + '.-',
remove=True)
# Remove listeners for any column definitions changing:
editor.on_trait_change(self.update_columns, 'columns', remove=True)
editor.on_trait_change(self.update_columns, 'columns_items',
remove=True)
# Make sure we have removed listeners from the current filter also:
if self.filter is not None:
self.filter.on_trait_change(self._filter_modified, remove=True)
# Clean-up any links that should be broken:
self.editor = None
#-------------------------------------------------------------------------
# Returns all model items matching the current filter:
#-------------------------------------------------------------------------
def get_filtered_items(self):
""" Returns all model items matching the current filter.
"""
return self.__filtered_items()
#-------------------------------------------------------------------------
# Returns a single specified item from those items matching the current
# filter:
#-------------------------------------------------------------------------
def get_filtered_item(self, index=0):
""" Returns a single specified item from those items matching the
current filter.
"""
try:
return self.__filtered_items()[index]
except:
logger.error('TableModel error: Request for invalid row %d out of '
'%d' % (index, len(self.__filtered_items())))
return None
#-------------------------------------------------------------------------
# Returns the raw, unfiltered index corresponding to a specified filtered
# index:
#-------------------------------------------------------------------------
def raw_index_of(self, row):
""" Returns the raw, unfiltered index corresponding to a specified
filtered index.
"""
if self._filtered_cache is None:
return row
return self.editor.filtered_indices[row]
#-------------------------------------------------------------------------
# Inserts an object after a specified filtered index:
#-------------------------------------------------------------------------
def insert_filtered_item_after(self, index, item):
""" Inserts an object after a specified filtered index.
"""
mapped_index = 0
n = len(self.editor.filtered_indices)
if index >= n:
if (index != 0) or (n != 0):
raise IndexError
elif index >= 0:
mapped_index = self.editor.filtered_indices[index] + 1
self.__items().insert(mapped_index, item)
sorted = self._sort_model()
if sorted:
mapped_index = self.__items().index(item)
self._filtered_cache = None
return (mapped_index, sorted)
#-------------------------------------------------------------------------
# Deletes the object at the specified filtered index:
#-------------------------------------------------------------------------
def delete_filtered_item_at(self, index):
""" Deletes the object at the specified filtered index.
"""
if index >= len(self.editor.filtered_indices):
raise IndexError
mapped_index = self.editor.filtered_indices[index]
items = self.__items()
object = items[mapped_index]
del items[mapped_index]
self._filtered_cache = None
return (mapped_index, object)
#-------------------------------------------------------------------------
# Updates the table view when columns have been changed:
#-------------------------------------------------------------------------
def update_columns(self):
""" Updates the table view when columns have been changed.
"""
self._columns = None
self.fire_structure_changed()
self.editor.refresh()
#-------------------------------------------------------------------------
# Resets any sorting being performed on the underlying model:
#-------------------------------------------------------------------------
def no_column_sort(self):
""" Resets any sorting being performed on the underlying model.
"""
self._sorter = self._filtered_cache = None
self.column_sorted = GridSortEvent(index=-1)
#self.fire_structure_changed()
#-- Event Handlers -------------------------------------------------------
#-------------------------------------------------------------------------
# Handles the contents of the filter being changed:
#-------------------------------------------------------------------------
@on_trait_change('filter.+')
def _filter_modified(self):
""" Handles the contents of the filter being changed.
"""
self._filtered_cache = None
self.fire_structure_changed()
self.editor.filter_modified()
#-------------------------------------------------------------------------
# Handles the grid firing a 'click' event:
#-------------------------------------------------------------------------
def _click_changed(self, event):
""" Handles the grid firing a 'click' event.
"""
row, col = event
# Fire the same event on the editor after mapping it to a model object
# and column name:
object = self.get_filtered_item(row)
column = self.__get_column(col)
self.editor.click = (object, column)
# Check to see if the column has a view to display:
view = column.get_view(object)
if view is not None:
column.get_object(object).edit_traits(
view=view,
parent=self._bounds_for(row, col))
# Invoke the column's click handler:
column.on_click(object)
#-------------------------------------------------------------------------
# Handles the grid firing a 'dclick' event:
#-------------------------------------------------------------------------
def _dclick_changed(self, event):
""" Handles the grid firing a 'dclick' event.
"""
row, col = event
# Fire the same event on the editor after mapping it to a model object
# and column name:
object = self.get_filtered_item(row)
column = self.__get_column(col)
self.editor.dclick = (object, column)
# Invoke the column's double-click handler:
column.on_dclick(object)
#-------------------------------------------------------------------------
# Handles the user modifying the current 'auto_add' mode row:
#-------------------------------------------------------------------------
def on_auto_add_row(self):
""" Handles the user modifying the current 'auto_add' mode row.
"""
object = self.auto_add_row
object.on_trait_change(self.on_auto_add_row, remove=True)
self.auto_add_row = row = self.editor.create_new_row()
if row is not None:
row.on_trait_change(self.on_auto_add_row, dispatch='ui')
do_later(self.editor.add_row, object,
len(self.get_filtered_items()) - 2)
#-- GridModel Interface --------------------------------------------------
def get_column_count(self):
""" Returns the number of columns for this table.
"""
return len(self.__get_columns())
def get_column_name(self, index):
""" Returns the label of the column specified by the (zero-based) index.
"""
return self.__get_column(index).get_label()
def get_column_size(self, index):
""" Returns the size in pixels of the column indexed by *index*.
A value of -1 or None means to use the default.
"""
return self.__get_column(index).get_width()
def get_cols_drag_value(self, cols):
""" Returns the value to use when the specified columns are dragged or
copied and pasted. The parameter *cols* is a list of column indexes.
"""
return [self.__get_data_column(col) for col in cols]
def get_cols_selection_value(self, cols):
""" Returns a list of TraitGridSelection objects containing the
objects corresponding to the grid rows and the traits corresponding
to the specified columns.
"""
values = []
for obj in self.__items(False):
values.extend([TraitGridSelection(
obj=obj,
name=self.__get_column_name(col))
for col in cols])
return values
def sort_by_column(self, col, reverse=False):
""" Sorts the model data by the column indexed by *col*.
"""
# Make sure we allow sorts by column:
factory = self.editor.factory
if not factory.sortable:
return
# Flush the object cache:
self._filtered_cache = None
# Cache the sorting information for later:
self._sorter = self.__get_column(col).key
self._reverse = reverse
# If model sorting is requested, do it now:
self._sort_model()
# Indicate the we have been sorted:
self.sorted = True
self.column_sorted = GridSortEvent(index=col,
reversed=reverse)
def is_column_read_only(self, index):
""" Returns True if the column specified by the zero-based *index* is
read-only.
"""
return (not self.__get_column(index).editable)
def get_row_count(self):
""" Return the number of rows for this table.
"""
return len(self.__filtered_items())
def get_row_name(self, index):
""" Return the name of the row specified by the (zero-based) *index*.
"""
return '<undefined>'
def get_rows_drag_value(self, rows):
""" Returns the value to use when the specified rows are dragged or
copied and pasted. The parameter *rows* is a list of row indexes.
If there is only one row listed, then return the corresponding trait
object. If more than one row is listed, then return a list of objects.
"""
items = self.__filtered_items()
return [items[row] for row in rows]
def get_rows_selection_value(self, rows):
""" Returns a list of TraitGridSelection objects containing the
object corresponding to the selected rows.
"""
items = self.__filtered_items()
return [TraitGridSelection(obj=items[row]) for row in rows]
def is_row_read_only(self, index):
""" Returns True if the row specified by the zero-based *index* is
read-only.
"""
return False
def get_cell_editor(self, row, col):
""" Returns the editor for the specified cell.
"""
if self.editor is None:
return None
column = self.__get_column(col)
object = self.get_filtered_item(row)
editor = column.get_editor(object)
if editor is None:
return None
editor._ui = self.editor.ui
target, name = column.target_name(object)
return TraitGridCellAdapter(editor, target, name, '',
context=self.editor.ui.context,
style=column.get_style(object),
width=column.get_edit_width(object),
height=column.get_edit_height(object))
def get_cell_renderer(self, row, col):
""" Returns the renderer for the specified cell.
"""
return self.__get_column(col).get_renderer(
self.get_filtered_item(row))
def get_cell_drag_value(self, row, col):
""" Returns the value to use when the specified cell is dragged or
copied and pasted.
"""
return self.__get_column(col).get_drag_value(
self.get_filtered_item(row))
def get_cell_selection_value(self, row, col):
""" Returns a TraitGridSelection object specifying the data stored
in the table at (*row*, *col*).
"""
return TraitGridSelection(obj=self.get_filtered_item(row),
name=self.__get_column_name(col))
def resolve_selection(self, selection_list):
""" Returns a list of (row, col) grid-cell coordinates that
correspond to the objects in *selection_list*. For each coordinate,
if the row is -1, it indicates that the entire column is selected.
Likewise coordinates with a column of -1 indicate an entire row
that is selected. For the TableModel, the objects in
*selection_list* must be TraitGridSelection objects.
"""
items = self.__filtered_items()
cells = []
for selection in selection_list:
row = -1
if selection.obj is not None:
try:
row = items.index(selection.obj)
except ValueError:
continue
column = -1
if selection.name != '':
column = self._get_column_index_by_trait(selection.name)
if column is None:
continue
cells.append((row, column))
return cells
def get_cell_context_menu(self, row, col):
""" Returns a Menu object that generates the appropriate context
menu for this cell.
"""
column = self.__get_column(col)
menu = column.get_menu(self.get_filtered_item(row))
editor = self.editor
if menu is None:
menu = editor.factory.menu
if menu is None:
menu_name = editor.factory.menu_name
if menu_name:
menu = getattr(self.editor.object, menu_name, None)
if menu is not None:
editor.prepare_menu(row, column)
return (menu, editor)
return None
def get_value(self, row, col):
""" Returns the value stored in the table at (*row*, *col*).
"""
object = self.get_filtered_item(row)
if object is self.auto_add_row:
return ''
value = self.__get_column(col).get_value(object)
formats = self.__get_column_formats(col)
if (value is not None) and (formats is not None):
format = formats.get(type(value))
if format is not None:
try:
if callable(format):
value = format(value)
else:
value = format % value
except:
pass
return value
def is_valid_cell_value(self, row, col, value):
""" Tests whether *value* is valid for the cell at (*row*, *col*).
Returns True if value is acceptable, and False otherwise. """
return self.__get_column(col).is_droppable(
self.get_filtered_item(row), value)
def is_cell_empty(self, row, col):
""" Returns True if the cell at (*row*, *col*) has a None value, and
False otherwise.
"""
return (self.get_value(row, col) is None)
def is_cell_read_only(self, row, col):
""" Returns True if the cell at (*row*, *col*) is read-only, and False
otherwise.
"""
return (not self.__get_column(col).is_editable(
self.get_filtered_item(row)))
def get_cell_bg_color(self, row, col):
""" Returns a wxColour object specifying the background color
of the specified cell.
"""
return self.__get_column(col).get_cell_color(
self.get_filtered_item(row))
def get_cell_text_color(self, row, col):
""" Returns a wxColour object specifying the text color of the
specified cell.
"""
column = self.__get_column(col)
item = self.get_filtered_item(row)
return column.get_text_color(item)
def get_cell_font(self, row, col):
""" Returns a wxFont object specifying the font of the specified cell.
"""
return self.__get_column(col).get_text_font(
self.get_filtered_item(row))
def get_cell_halignment(self, row, col):
""" Returns a string specifying the horizontal alignment of the
specified cell.
Returns 'left' for left alignment, 'right' for right alignment,
or 'center' for center alignment.
"""
return self.__get_column(col).get_horizontal_alignment(
self.get_filtered_item(row))
def get_cell_valignment(self, row, col):
""" Returns a string specifying the vertical alignment of the
specified cell.
Returns 'top' for top alignment, 'bottom' for bottom alignment,
or 'center' for center alignment.
"""
return self.__get_column(col).get_vertical_alignment(
self.get_filtered_item(row))
#-------------------------------------------------------------------------
# Protected 'GridModel' interface:
#-------------------------------------------------------------------------
def _insert_rows(self, pos, num_rows):
""" Inserts *num_rows* at *pos*; fires an event only if a factory
method for new rows is defined or the model is not empty. Otherwise,
it returns 0.
"""
count = 0
factory = self.editor.factory.row_factory
if factory is None:
items = self.__items(False)
if len(items) > 0:
factory = items[0].__class__
if factory is not None:
new_data = [x for x in [factory() for i in range(num_rows)]
if x is not None]
if len(new_data) > 0:
count = self._insert_rows_into_model(pos, new_data)
self.rows_added = ('added', pos, new_data)
return count
def _delete_rows(self, pos, num_rows):
""" Removes rows *pos* through *pos* + *num_rows* from the model.
"""
row_count = self.get_rows_count()
if (pos + num_rows) > row_count:
num_rows = row_count - pos
return self._delete_rows_from_model(pos, num_rows)
def _set_value(self, row, col, value):
""" Sets the value of the cell at (*row*, *col*) to *value*.
Raises a ValueError if the value is vetoed or the cell at
the specified position does not exist.
"""
new_rows = 0
column = self.__get_column(col)
obj = None
try:
obj = self.get_filtered_item(row)
except:
# Add a new row:
new_rows = self._insert_rows(self.get_row_count(), 1)
if new_rows > 0:
# Now set the value on the new object:
try:
obj = self.get_filtered_item(self.get_row_count() - 1)
except:
# fixme: what do we do in this case? veto the set somehow?
# raise an exception?
pass
if obj is not None:
self._set_data_on_row(obj, column, value)
return new_rows
def _move_column(self, frm, to):
""" Moves a specified **frm** column to before the specified **to**
column. Returns **True** if successful; **False** otherwise.
"""
to_column = None
if to < len(self.__get_columns()):
to_column = self.__get_column(to)
return self.editor.move_column(self.__get_column(frm), to_column)
#-------------------------------------------------------------------------
# Protected interface:
#-------------------------------------------------------------------------
def _set_data_on_row(self, row, column, value):
""" Sets the cell specified by (*row*, *col*) to *value, which
can be either a member of the row object, or a no-argument method
on that object.
"""
column.set_value(row, value)
def _insert_rows_into_model(self, pos, new_data):
""" Inserts the given new rows into the model.
"""
raw_pos = self.raw_index_of(pos)
self.__items()[raw_pos: raw_pos] = new_data
def _delete_rows_from_model(self, pos, num_rows):
""" Deletes the specified rows from the model.
"""
raw_rows = sorted([self.raw_index_of(i)
for i in range(pos, pos + num_rows)])
raw_rows.reverse()
items = self.__items()
for row in raw_rows:
del items[row]
return num_rows
#-------------------------------------------------------------------------
# Trait event handlers:
#-------------------------------------------------------------------------
def _on_data_changed(self):
""" Forces the grid to refresh when the underlying list changes.
"""
# Invalidate the current cache (if any):
self._filtered_cache = None
self.fire_structure_changed()
def _mouse_cell_changed(self, new):
""" Handles the user mousing over a specified cell.
"""
row, col = new
column = self.__get_column(col)
object = self.get_filtered_item(row)
# Update the tooltip if necessary:
tooltip = column.get_tooltip(object)
if tooltip != self._tooltip:
self._tooltip = tooltip
self.editor.grid._grid_window.SetToolTip(wx.ToolTip(tooltip))
if column.is_auto_editable(object):
x, y, dx, dy = self._bounds_for(row, col)
if column.is_editable(object):
view = View(Item(name=column.name,
editor=column.get_editor(object),
style=column.get_style(object),
show_label=False,
padding=-4),
kind='info',
width=dx,
height=dy)
else:
view = column.get_view(object)
if view is None:
return
column.get_object(object).edit_traits(
view=view,
parent=(x, y, dx, dy))
#-------------------------------------------------------------------------
# Private interface:
#-------------------------------------------------------------------------
def _bounds_for(self, row, col):
""" Returns the coordinates and size of the specified cell in the form:
( x, y, dx, dy ).
"""
grid = self.editor.grid
coords = wxg.GridCellCoords(row, col)
x, y, dx, dy = grid._grid.BlockToDeviceRect(coords, coords)
x, y = grid._grid_window.ClientToScreenXY(x, y)
return (x, y, dx, dy)
def _sort_model(self):
""" Sorts the underlying model if that is what the user requested.
"""
editor = self.editor
sorted = (editor.factory.sort_model and (self._sorter is not None))
if sorted:
items = self.__items(False)[:]
items.sort(key=self._sorter)
if self.reverse ^ self._reverse:
items.reverse()
editor.value = items
return sorted
def __items(self, ordered=True):
""" Returns the raw list of model objects.
"""
result = self.editor.value
if not isinstance(result, SequenceTypes):
return [result]
if ordered and self.reverse:
return ReversedList(result)
return result
def __filtered_items(self):
""" Returns the list of all model objects that pass the current filter.
"""
fc = self._filtered_cache
if fc is None:
items = self.__items()
filter = self.filter
if filter is None:
nitems = [nitem for nitem in enumerate(items)]
self.filter_summary = 'All %s items' % len(nitems)
else:
if not callable(filter):
filter = filter.filter
nitems = [nitem for nitem in enumerate(items)
if filter(nitem[1])]
self.filter_summary = '%s of %s items' % (len(nitems),
len(items))
sorter = self._sorter
if sorter is not None:
nitems.sort(key=lambda x: sorter(x[1]))
if self._reverse:
nitems.reverse()
self.editor.filtered_indices = [x[0] for x in nitems]
self._filtered_cache = fc = [x[1] for x in nitems]
if self.auto_add_row is not None:
self._filtered_cache.append(self.auto_add_row)
return fc
def __get_data_column(self, col):
""" Returns a list of model data from the column indexed by *col*.
"""
column = self.__get_column(col)
return [column.get_value(item) for item in self.__filtered_items()]
def __get_columns(self):
columns = self._columns
if columns is None:
self._columns = columns = [c for c in self.editor.columns
if c.visible]
return columns
def __get_column(self, col):
try:
return self.__get_columns()[col]
except:
return self.__get_columns()[0]
def __get_column_name(self, col):
return self.__get_column(col).name
def __get_column_formats(self, col):
return None # Not used/implemented currently
def _get_column_index_by_trait(self, name):
for i, col in enumerate(self.__get_columns()):
if name == col.name:
return i