from .base import _check_classifer_response_method
from .. import average_precision_score
from .. import precision_recall_curve
from ...utils import check_matplotlib_support
from ...base import is_classifier
class PrecisionRecallDisplay:
"""Precision Recall visualization.
It is recommend to use :func:`~sklearn.metrics.plot_precision_recall_curve`
to create a visualizer. All parameters are stored as attributes.
Read more in the :ref:`User Guide <visualizations>`.
Parameters
-----------
precision : ndarray
Precision values.
recall : ndarray
Recall values.
average_precision : float
Average precision.
estimator_name : str
Name of estimator.
Attributes
----------
line_ : matplotlib Artist
Precision recall curve.
ax_ : matplotlib Axes
Axes with precision recall curve.
figure_ : matplotlib Figure
Figure containing the curve.
"""
def __init__(self, precision, recall, average_precision, estimator_name):
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.estimator_name = estimator_name
def plot(self, ax=None, name=None, **kwargs):
"""Plot visualization.
Extra keyword arguments will be passed to matplotlib's `plot`.
Parameters
----------
ax : Matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
name : str, default=None
Name of precision recall curve for labeling. If `None`, use the
name of the estimator.
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
Object that stores computed values.
"""
check_matplotlib_support("PrecisionRecallDisplay.plot")
import matplotlib.pyplot as plt
if ax is None:
fig, ax = plt.subplots()
name = self.estimator_name if name is None else name
line_kwargs = {
"label": "{} (AP = {:0.2f})".format(name,
self.average_precision),
"drawstyle": "steps-post"
}
line_kwargs.update(**kwargs)
self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
ax.set(xlabel="Recall", ylabel="Precision")
ax.legend(loc='lower left')
self.ax_ = ax
self.figure_ = ax.figure
return self
def plot_precision_recall_curve(estimator, X, y,
sample_weight=None, response_method="auto",
name=None, ax=None, **kwargs):
"""Plot Precision Recall Curve for binary classifers.
Extra keyword arguments will be passed to matplotlib's `plot`.
Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.
Parameters
----------
estimator : estimator instance
Trained classifier.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Binary target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
response_method : {'predict_proba', 'decision_function', 'auto'}, \
default='auto'
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.
name : str, default=None
Name for labeling curve. If `None`, the name of the
estimator is used.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
Object that stores computed values.
"""
check_matplotlib_support("plot_precision_recall_curve")
classification_error = ("{} should be a binary classifer".format(
estimator.__class__.__name__))
if not is_classifier(estimator):
raise ValueError(classification_error)
prediction_method = _check_classifer_response_method(estimator,
response_method)
y_pred = prediction_method(X)
if y_pred.ndim != 1:
if y_pred.shape[1] != 2:
raise ValueError(classification_error)
else:
y_pred = y_pred[:, 1]
pos_label = estimator.classes_[1]
precision, recall, _ = precision_recall_curve(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)
average_precision = average_precision_score(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)
viz = PrecisionRecallDisplay(precision, recall, average_precision,
estimator.__class__.__name__)
return viz.plot(ax=ax, name=name, **kwargs)