Repository URL to install this package:
|
Version:
0.13.2 ▾
|
"""
Components for parsing variable assignments and internally representing plot data.
"""
from __future__ import annotations
from collections.abc import Mapping, Sized
from typing import cast
import warnings
import pandas as pd
from pandas import DataFrame
from seaborn._core.typing import DataSource, VariableSpec, ColumnName
from seaborn.utils import _version_predates
class PlotData:
    """
    Data table with plot variable schema and mapping to original names.

    Contains logic for parsing variable specification arguments and updating
    the table with layer-specific data and/or mappings.

    Parameters
    ----------
    data
        Input data where variable names map to vector values.
    variables
        Keys are names of plot variables (x, y, ...) each value is one of:

        - name of a column (or index level, or dictionary entry) in `data`
        - vector in any format that can construct a :class:`pandas.DataFrame`

    Attributes
    ----------
    frame
        Data table with column names having defined plot variables.
    frames
        Dictionary of per-variable data tables; empty on construction.
    names
        Dictionary mapping plot variable names to names in source data structure(s).
    ids
        Dictionary mapping plot variable names to unique data source identifiers.
    source_data
        The (normalized) data object this dataset was originally built from.
    source_vars
        The variable specification this dataset was originally built from.

    """
    frame: DataFrame
    frames: dict[tuple, DataFrame]
    names: dict[str, str | None]
    ids: dict[str, str | int]
    source_data: DataSource
    source_vars: dict[str, VariableSpec]

    def __init__(
        self,
        data: DataSource,
        variables: dict[str, VariableSpec],
    ):
        # Normalize the data source (DataFrame / Mapping / None) before parsing
        data = handle_data_source(data)
        frame, names, ids = self._assign_variables(data, variables)

        self.frame = frame
        self.names = names
        self.ids = ids

        # The reason we possibly have a dictionary of frames is to support the
        # Plot.pair operation, post scaling, where each x/y variable needs its
        # own frame. This feels pretty clumsy and there are a bunch of places in
        # the client code with awkward if frame / elif frames constructions.
        # It would be great to have a cleaner abstraction here.
        self.frames = {}

        # Remember the inputs so that `join` can inherit from them later
        self.source_data = data
        self.source_vars = variables

    def __contains__(self, key: str) -> bool:
        """Boolean check on whether a variable is defined in this dataset."""
        if self.frame is None:
            # Without a single frame, look through the per-variable frames
            return any(key in df for df in self.frames.values())
        return key in self.frame

    def join(
        self,
        data: DataSource,
        variables: dict[str, VariableSpec] | None,
    ) -> PlotData:
        """
        Add, replace, or drop variables and return as a new dataset.

        Parameters
        ----------
        data
            Data source for the new variables; when None, the source data of
            this object is reused.
        variables
            Variable specification for the new dataset; when empty or None, the
            source variables of this object are reused. A value of None for a
            key removes that variable from the result.

        Returns
        -------
        new
            A new :class:`PlotData`; this object is not modified.

        """
        # Inherit the original source of the upstream data by default
        if data is None:
            data = self.source_data

        # TODO allow `data` to be a function (that is called on the source data?)

        if not variables:
            variables = self.source_vars

        # Passing var=None implies that we do not want that variable in this layer
        disinherit = [k for k, v in variables.items() if v is None]

        # Create a new dataset with just the info passed here
        new = PlotData(data, variables)

        # -- Update the inherited DataSource with this new information

        # Columns that are redefined or explicitly dropped must not be inherited
        drop_cols = [k for k in self.frame if k in new.frame or k in disinherit]
        parts = [self.frame.drop(columns=drop_cols), new.frame]

        # Because we are combining distinct columns, this is perhaps more
        # naturally thought of as a "merge"/"join". But using concat because
        # some simple testing suggests that it is marginally faster.
        frame = pd.concat(parts, axis=1, sort=False, copy=False)

        names = {k: v for k, v in self.names.items() if k not in disinherit}
        names.update(new.names)

        ids = {k: v for k, v in self.ids.items() if k not in disinherit}
        ids.update(new.ids)

        new.frame = frame
        new.names = names
        new.ids = ids

        # Multiple chained operations should always inherit from the original object
        new.source_data = self.source_data
        new.source_vars = self.source_vars

        return new

    def _assign_variables(
        self,
        data: DataFrame | Mapping | None,
        variables: dict[str, VariableSpec],
    ) -> tuple[DataFrame, dict[str, str | None], dict[str, str | int]]:
        """
        Assign values for plot variables given long-form data and/or vector inputs.

        Parameters
        ----------
        data
            Input data where variable names map to vector values.
        variables
            Keys are names of plot variables (x, y, ...) each value is one of:

            - name of a column (or index level, or dictionary entry) in `data`
            - vector in any format that can construct a :class:`pandas.DataFrame`

        Returns
        -------
        frame
            Table mapping seaborn variables (x, y, color, ...) to data vectors.
        names
            Keys are defined seaborn variables; values are names inferred from
            the inputs (or None when no name can be determined, which includes
            pandas objects whose ``name`` is None).
        ids
            Like the `names` dict, but `None` values are replaced by the `id()`
            of the data object that defined the variable.

        Raises
        ------
        ValueError
            When variables are strings that don't appear in `data`, or when they are
            non-indexed vector datatypes that have a different length from `data`.

        """
        source_data: Mapping | DataFrame
        frame: DataFrame
        names: dict[str, str | None]
        ids: dict[str, str | int]

        plot_data = {}
        names = {}
        ids = {}

        given_data = data is not None
        if data is None:
            # Data is optional; all variables can be defined as vectors
            # But simplify downstream code by always having a usable source data object
            source_data = {}
        else:
            source_data = data

        # Variables can also be extracted from the index of a DataFrame
        if isinstance(source_data, pd.DataFrame):
            index = source_data.index.to_frame().to_dict("series")
        else:
            index = {}

        for key, val in variables.items():

            # Simply ignore variables with no specification
            if val is None:
                continue

            # Try to treat the argument as a key for the data collection.
            # But be flexible about what can be used as a key.
            # Usually it will be a string, but allow other hashables when
            # taking from the main data object. Allow only strings to reference
            # fields in the index, because otherwise there is too much ambiguity.

            # TODO this will be rendered unnecessary by the following pandas fix:
            # https://github.com/pandas-dev/pandas/pull/41283
            try:
                hash(val)
                val_is_hashable = True
            except TypeError:
                val_is_hashable = False

            val_as_data_key = (
                # See https://github.com/pandas-dev/pandas/pull/41283
                # (isinstance(val, abc.Hashable) and val in source_data)
                (val_is_hashable and val in source_data)
                or (isinstance(val, str) and val in index)
            )

            if val_as_data_key:
                val = cast(ColumnName, val)
                # Columns / dictionary entries take precedence over index levels
                if val in source_data:
                    plot_data[key] = source_data[val]
                elif val in index:
                    plot_data[key] = index[val]
                names[key] = ids[key] = str(val)

            elif isinstance(val, str):

                # This looks like a column name but, lookup failed.

                err = f"Could not interpret value `{val}` for `{key}`. "
                if not given_data:
                    err += "Value is a string, but `data` was not passed."
                else:
                    err += "An entry with this name does not appear in `data`."
                raise ValueError(err)

            else:

                # Otherwise, assume the value somehow represents data

                # Ignore empty data structures
                if isinstance(val, Sized) and len(val) == 0:
                    continue

                # If vector has no index, it must match length of data table
                if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series):
                    if isinstance(val, Sized) and len(data) != len(val):
                        val_cls = val.__class__.__name__
                        err = (
                            f"Length of {val_cls} vectors must match length of `data`"
                            f" when both are used, but `data` has length {len(data)}"
                            f" and the vector passed to `{key}` has length {len(val)}."
                        )
                        raise ValueError(err)

                plot_data[key] = val

                # Try to infer the original name using pandas-like metadata.
                # An object whose `name` is None (e.g. an unnamed Series) carries
                # no usable name, so handle it like a plain vector rather than
                # recording the literal string "None" as its name and id.
                name = getattr(val, "name", None)
                if name is None:
                    names[key] = None
                    ids[key] = id(val)
                else:
                    names[key] = ids[key] = str(name)

        # Construct a tidy plot DataFrame. This will convert a number of
        # types automatically, aligning on index in case of pandas objects
        # TODO Note: this fails when variable specs *only* have scalars!
        frame = pd.DataFrame(plot_data)

        return frame, names, ids
def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
    """
    Normalize a user-supplied data source to DataFrame, Mapping, or None.

    Objects that are pandas DataFrames or expose ``__dataframe__`` are passed
    through :func:`convert_dataframe_to_pandas`; mappings and None are returned
    untouched; anything else raises a TypeError.
    """
    # The explicit pd.DataFrame check could be dropped once the minimal
    # supported pandas version implements dataframe interchange (1.5.0).
    frame_like = isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__")
    if frame_like:
        return convert_dataframe_to_pandas(data)

    if data is None or isinstance(data, Mapping):
        return data

    err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
    raise TypeError(err)
def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
    """
    Return `data` as a pandas DataFrame via the interchange protocol.

    pandas DataFrames are returned unchanged. Other objects are converted with
    ``pd.api.interchange.from_dataframe``. A TypeError is raised when the
    installed pandas has no interchange support, a warning is issued for
    pandas older than 2.0.2, and any conversion failure is re-raised as a
    RuntimeError chained to the underlying exception.
    """
    # Nothing to convert when we already hold a pandas object
    if isinstance(data, pd.DataFrame):
        return data

    interchange_available = hasattr(pd.api, "interchange")
    if not interchange_available:
        unsupported = (
            "Support for non-pandas DataFrame objects requires a version of pandas "
            "that implements the DataFrame interchange protocol. Please upgrade "
            "your pandas version or coerce your data to pandas before passing "
            "it to seaborn."
        )
        raise TypeError(unsupported)

    if _version_predates(pd, "2.0.2"):
        caution = (
            "DataFrame interchange with pandas<2.0.2 has some known issues. "
            f"You are using pandas {pd.__version__}. "
            "Continuing, but it is recommended to carefully inspect the results and to "
            "consider upgrading."
        )
        warnings.warn(caution, stacklevel=2)

    try:
        # Every column of the input is converted here, although the plot may
        # need only one or two of them. Selecting the needed columns before
        # interchange would be cheaper, but is hard to do in general: with the
        # objects interface, variables passed in Plot() may only be referenced
        # later in Plot.add(). Noted in case this turns out to be a bottleneck.
        converted = pd.api.interchange.from_dataframe(data)
    except Exception as err:
        failure = (
            "Encountered an exception when converting data source "
            "to a pandas DataFrame. See traceback above for details."
        )
        raise RuntimeError(failure) from err

    return converted