from __future__ import absolute_import
import six
from six.moves import range, zip
from pytools import Record
class Row(Record):
    """A single table row, as returned by C{DataTable.get}.

    C{DataTable.get} constructs instances from a dictionary that maps
    column names to the values of the matching row. All behavior is
    inherited unchanged from C{Record}.
    """
class DataTable:
    """An in-memory relational database table.

    The table consists of a sequence of column names (C{column_names}),
    a mapping from column name to column position (C{column_indices})
    and a list of rows (C{data}). Rows are expected to be tuples with
    one entry per column.
    """

    def __init__(self, column_names, column_data=None):
        """Construct a new table, with the given C{column_names}.

        @arg column_names: An indexable of column name strings.
        @arg column_data: None or a list of tuples of the same length as
          C{column_names} indicating an initial set of data. The list is
          stored by reference, not copied.
        @raise RuntimeError: if a column name occurs more than once.
        """
        if column_data is None:
            column_data = []

        self.data = column_data
        self.column_names = column_names
        self.column_indices = {
                colname: i for i, colname in enumerate(column_names)}

        # Duplicate names collapse into a single dictionary key.
        if len(self.column_indices) != len(self.column_names):
            raise RuntimeError("non-unique column names encountered")

    def __bool__(self):
        """Return True if the table contains at least one row."""
        return bool(self.data)

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.data)

    def __iter__(self):
        """Iterate over the rows of the table."""
        return iter(self.data)

    def __str__(self):
        """Return a pretty-printed version of the table."""
        # Each column is as wide as its widest cell, header included.
        col_widths = [
                max([len(name)] + [len(str(row[i])) for row in self.data])
                for i, name in enumerate(self.column_names)]

        def format_row(row):
            return "|".join(
                    str(cell).ljust(width)
                    for cell, width in zip(row, col_widths))

        lines = [
                format_row(self.column_names),
                "+".join("-" * width for width in col_widths)]
        lines.extend(format_row(row) for row in self.data)
        return "\n".join(lines)

    def insert(self, **kwargs):
        """Append one row given as C{column_name=value} keyword arguments.

        Columns that are not mentioned are set to None.

        @raise KeyError: if a keyword is not a column name of this table.
        """
        values = [None] * len(self.column_names)
        for key, val in kwargs.items():
            values[self.column_indices[key]] = val

        self.insert_row(tuple(values))

    def insert_row(self, values):
        """Append the tuple C{values} as a new row.

        @arg values: A tuple with exactly one entry per column.
        """
        assert isinstance(values, tuple)
        assert len(values) == len(self.column_names)
        self.data.append(values)

    def insert_rows(self, rows):
        """Append every row of the iterable C{rows}, see L{insert_row}."""
        for row in rows:
            self.insert_row(row)

    def filtered(self, **kwargs):
        """Return the rows whose columns equal the given keyword values.

        Without any criteria, the table itself (not a copy) is returned.
        Otherwise a new table is returned that shares its row objects
        with this one.

        @raise KeyError: if a keyword is not a column name of this table.
        """
        if not kwargs:
            return self

        criteria = tuple(
                (self.column_indices[key], value)
                for key, value in kwargs.items())

        result_data = [
                row for row in self.data
                if not any(row[idx] != val for idx, val in criteria)]

        return DataTable(self.column_names, result_data)

    def get(self, **kwargs):
        """Return the single row matching the given keyword criteria.

        @return: A L{Row} built from a dictionary mapping column names
          to the values of the matching row.
        @raise RuntimeError: if no row or more than one row matches.
        """
        filtered = self.filtered(**kwargs)

        if not filtered:
            raise RuntimeError("no matching entry for get()")
        if len(filtered) > 1:
            raise RuntimeError("more than one matching entry for get()")

        return Row(dict(zip(self.column_names, filtered.data[0])))

    def clear(self):
        """Remove all rows, in place."""
        del self.data[:]

    def copy(self):
        """Make a copy of the instance, but leave individual rows untouched.

        If the rows are modified later, they will also be modified in the copy.
        """
        return DataTable(self.column_names, self.data[:])

    def deep_copy(self):
        """Make a copy of the instance down to the row level.

        The copy's rows may be modified independently from the original.
        """
        return DataTable(self.column_names, [row[:] for row in self.data])

    def sort(self, columns, reverse=False):
        """Sort the rows in place by the given C{columns}.

        @arg columns: An iterable of column names; earlier names take
          precedence over later ones.
        @arg reverse: If true, sort in descending order.
        """
        col_indices = [self.column_indices[col] for col in columns]

        def sort_key(row):
            return tuple(row[col_index] for col_index in col_indices)

        self.data.sort(reverse=reverse, key=sort_key)

    def aggregated(self, groupby, agg_column, aggregate_func):
        """Return a table that aggregates C{agg_column} per group.

        Only consecutive rows with equal C{groupby} values are combined
        into a group, so the table should be sorted by those columns.

        @arg groupby: A list of column names that define the groups.
        @arg agg_column: The name of the column to aggregate.
        @arg aggregate_func: A callable that receives the list of
          C{agg_column} values of one group and returns their aggregate.
        @return: A new table with the C{groupby} columns followed by
          C{agg_column}, containing one row per group.
        """
        from itertools import groupby as group_consecutive

        gb_indices = [self.column_indices[col] for col in groupby]
        agg_index = self.column_indices[agg_column]

        def group_key(row):
            return tuple(row[i] for i in gb_indices)

        result_data = [
                group_values
                + (aggregate_func([row[agg_index] for row in group_rows]),)
                for group_values, group_rows
                in group_consecutive(self.data, key=group_key)]

        return DataTable(
                [self.column_names[i] for i in gb_indices] + [agg_column],
                result_data)

    def join(self, column, other_column, other_table, outer=False):
        """Return a table joining this and the C{other_table} on C{column}.

        The new table has the following columns:
        - C{column}, titled the same as in this table.
        - the columns of this table, minus C{column}.
        - the columns of C{other_table}, minus C{other_column}.

        Assumes both tables are sorted ascendingly by the column
        by which they are joined, and that rows are tuples.

        @arg column: The name of the join column in this table.
        @arg other_column: The name of the join column in C{other_table}.
        @arg other_table: The L{DataTable} to join with.
        @arg outer: If true, rows without a partner in the respective
          other table are kept and padded with None. Otherwise they are
          dropped.
        @raise RuntimeError: if two join keys are neither equal nor
          ordered relative to each other.
        """  # pylint:disable=too-many-locals,too-many-branches
        from itertools import groupby
        from operator import itemgetter

        def without(indexable, idx):
            return indexable[:idx] + indexable[idx+1:]

        def batches(rows, key_idx):
            # Yield (key, rows) for each run of consecutive equal keys.
            # The rows are materialized because groupby invalidates a
            # group as soon as it advances to the next one.
            for batch_key, batch_rows in groupby(rows, key=itemgetter(key_idx)):
                yield batch_key, list(batch_rows)

        this_key_idx = self.column_indices[column]
        other_key_idx = other_table.column_indices[other_column]

        # Copy to lists so that the concatenation below works for any
        # indexable of column names.
        this_names = list(self.column_names)
        other_names = list(other_table.column_names)
        result_columns = [this_names[this_key_idx]] + \
                without(this_names, this_key_idx) + \
                without(other_names, other_key_idx)

        # Stand-in batches for the side that has no partner (outer join).
        this_padding = [(None,) * len(this_names)]
        other_padding = [(None,) * len(other_names)]

        this_batches = batches(self.data, this_key_idx)
        other_batches = batches(other_table.data, other_key_idx)

        # None marks an exhausted side; this also covers empty tables.
        this_current = next(this_batches, None)
        other_current = next(other_batches, None)

        result_data = []
        while True:
            this_over = this_current is None
            other_over = other_current is None

            if outer:
                if this_over and other_over:
                    break
            elif this_over or other_over:
                # An inner join cannot produce rows once one side is done.
                break

            if this_over:
                run_this, run_other = False, True
            elif other_over:
                run_this, run_other = True, False
            else:
                this_key = this_current[0]
                other_key = other_current[0]
                keys_equal = this_key == other_key
                run_this = keys_equal or this_key < other_key
                run_other = keys_equal or this_key > other_key
                if not (run_this or run_other):
                    # Neither side could advance, e.g. for NaN keys.
                    raise RuntimeError("join keys are not comparable")

            if run_this:
                key, this_batch = this_current
                this_current = next(this_batches, None)
            else:
                this_batch = this_padding if outer else []

            if run_other:
                key, other_batch = other_current
                other_current = next(other_batches, None)
            else:
                other_batch = other_padding if outer else []

            for this_batch_row in this_batch:
                for other_batch_row in other_batch:
                    result_data.append((key,)
                            + without(this_batch_row, this_key_idx)
                            + without(other_batch_row, other_key_idx))

        return DataTable(result_columns, result_data)

    def restricted(self, columns):
        """Return a new table containing only the given C{columns}.

        @arg columns: An indexable of column names, in the desired order.
        @return: A new table whose rows are tuples, like those accepted
          by L{insert_row}.
        """
        col_indices = [self.column_indices[col] for col in columns]

        return DataTable(columns,
                [tuple(row[i] for i in col_indices) for row in self.data])

    def column_data(self, column):
        """Return a list of the values in the column named C{column}."""
        col_index = self.column_indices[column]
        return [row[col_index] for row in self.data]

    def write_csv(self, filelike, **kwargs):
        """Write the column names and all rows to C{filelike} as CSV.

        @arg filelike: The object handed to C{csv.writer}.
        @arg kwargs: Passed on to C{csv.writer} unchanged.
        """
        from csv import writer
        csvwriter = writer(filelike, **kwargs)
        csvwriter.writerow(self.column_names)
        csvwriter.writerows(self.data)