import pytest
import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_array_equal
from sklearn.compose import make_column_transformer
from sklearn.datasets import make_classification
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR
from sklearn.metrics import confusion_matrix
from sklearn.metrics import plot_confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
# Module-wide warning filter: ignore the numpy 'np.bool_' DeprecationWarning
# when it is attributed to a matplotlib module (see the TODO above for when
# this can be dropped). The pieces below concatenate to a single
# "action:message:category:module" filter string.
pytestmark = pytest.mark.filterwarnings(
    "ignore"
    ":In future, it will be an error for 'np.bool_'"
    ":DeprecationWarning"
    ":matplotlib.*"
)
@pytest.fixture(scope="module")
def n_classes():
    """Number of target classes in the shared synthetic dataset."""
    num_classes = 5
    return num_classes
@pytest.fixture(scope="module")
def data(n_classes):
    """Seeded synthetic multiclass dataset shared by the whole module.

    Returns
    -------
    tuple
        ``(X, y)`` as produced by ``make_classification`` with 100 samples,
        5 informative features and ``n_classes`` classes.
    """
    features, target = make_classification(
        n_samples=100,
        n_informative=5,
        n_classes=n_classes,
        random_state=0,
    )
    return features, target
@pytest.fixture(scope="module")
def fitted_clf(data):
    """Linear-kernel ``SVC`` (C=0.01) fitted once on the shared dataset."""
    X, y = data
    classifier = SVC(kernel='linear', C=0.01)
    return classifier.fit(X, y)
@pytest.fixture(scope="module")
def y_pred(data, fitted_clf):
    """Predictions of ``fitted_clf`` on the training features."""
    features, _unused_target = data
    predictions = fitted_clf.predict(features)
    return predictions
def test_error_on_regressor(pyplot, data):
    """``plot_confusion_matrix`` must reject a fitted regressor.

    ``pyplot`` is requested only for its fixture setup (presumably a
    matplotlib availability/cleanup fixture from conftest; verify).
    """
    X, y = data
    regressor = SVR()
    regressor.fit(X, y)

    expected_msg = "plot_confusion_matrix only supports classifiers"
    with pytest.raises(ValueError, match=expected_msg):
        plot_confusion_matrix(regressor, X, y)
def test_error_on_invalid_option(pyplot, fitted_clf, data):
    """An unknown ``normalize`` value raises ValueError naming the valid set."""
    X, y = data
    # Regex: braces are escaped so they match literally.
    expected_msg = (
        r"normalize must be one of "
        r"\{'true', 'pred', 'all', None\}"
    )
    with pytest.raises(ValueError, match=expected_msg):
        plot_confusion_matrix(fitted_clf, X, y, normalize='invalid')
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
def test_plot_confusion_matrix_custom_labels(pyplot, data, y_pred, fitted_clf,
                                             n_classes, with_labels,
                                             with_display_labels):
    """Custom ``labels`` / ``display_labels`` drive the matrix and tick text.

    The expected tick text is ``display_labels`` when given, else ``labels``
    when given, else the class indices ``0 .. n_classes - 1``.
    """
    X, y = data
    ax = pyplot.gca()

    labels = None
    display_labels = None
    if with_labels:
        labels = [2, 1, 0, 3, 4]
    if with_display_labels:
        display_labels = ['b', 'd', 'a', 'e', 'f']

    expected_cm = confusion_matrix(y, y_pred, labels=labels)
    disp = plot_confusion_matrix(fitted_clf, X, y,
                                 ax=ax,
                                 display_labels=display_labels,
                                 labels=labels)
    assert_allclose(disp.confusion_matrix, expected_cm)

    # display_labels wins over labels; with neither, class indices are shown.
    if display_labels is not None:
        expected_display_labels = display_labels
    elif labels is not None:
        expected_display_labels = labels
    else:
        expected_display_labels = list(range(n_classes))

    expected_tick_text = [str(label) for label in expected_display_labels]
    x_tick_text = [t.get_text() for t in disp.ax_.get_xticklabels()]
    y_tick_text = [t.get_text() for t in disp.ax_.get_yticklabels()]

    assert_array_equal(disp.display_labels, expected_display_labels)
    assert_array_equal(x_tick_text, expected_tick_text)
    assert_array_equal(y_tick_text, expected_tick_text)
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("include_values", [True, False])
def test_plot_confusion_matrix(pyplot, data, y_pred, n_classes, fitted_clf,
                               normalize, include_values):
    """Check the matrix, image, axes, labels, ticks and cell text produced
    by ``plot_confusion_matrix`` for every ``normalize`` mode."""
    X, y = data
    ax = pyplot.gca()
    cmap = 'plasma'
    expected_cm = confusion_matrix(y, y_pred)

    disp = plot_confusion_matrix(fitted_clf, X, y,
                                 normalize=normalize,
                                 cmap=cmap, ax=ax,
                                 include_values=include_values)
    assert disp.ax_ is ax

    # Re-derive the normalisation independently: 'true' divides by row sums,
    # 'pred' by column sums, 'all' by the grand total; None leaves counts.
    normalizers = {
        'true': lambda m: m / m.sum(axis=1, keepdims=True),
        'pred': lambda m: m / m.sum(axis=0, keepdims=True),
        'all': lambda m: m / m.sum(),
    }
    if normalize in normalizers:
        expected_cm = normalizers[normalize](expected_cm)
    assert_allclose(disp.confusion_matrix, expected_cm)

    import matplotlib as mpl
    assert isinstance(disp.im_, mpl.image.AxesImage)
    assert disp.im_.get_cmap().name == cmap
    assert isinstance(disp.ax_, pyplot.Axes)
    assert isinstance(disp.figure_, pyplot.Figure)
    assert disp.ax_.get_ylabel() == "True label"
    assert disp.ax_.get_xlabel() == "Predicted label"

    expected_display_labels = list(range(n_classes))
    expected_tick_text = [str(label) for label in expected_display_labels]
    x_tick_text = [t.get_text() for t in disp.ax_.get_xticklabels()]
    y_tick_text = [t.get_text() for t in disp.ax_.get_yticklabels()]
    assert_array_equal(disp.display_labels, expected_display_labels)
    assert_array_equal(x_tick_text, expected_tick_text)
    assert_array_equal(y_tick_text, expected_tick_text)

    assert_allclose(disp.im_.get_array().data, expected_cm)

    if not include_values:
        assert disp.text_ is None
        return

    # One text artist per cell, formatted with the '.2g' format.
    assert disp.text_.shape == (n_classes, n_classes)
    expected_text = np.array(
        [format(value, '.2g') for value in expected_cm.ravel(order="C")])
    actual_text = np.array(
        [t.get_text() for t in disp.text_.ravel(order="C")])
    assert_array_equal(expected_text, actual_text)
def test_confusion_matrix_display(pyplot, data, fitted_clf, y_pred, n_classes):
    """Calling ``disp.plot`` again on an existing display applies each new
    option (cmap, include_values, xticks_rotation, values_format)."""
    X, y = data
    expected_cm = confusion_matrix(y, y_pred)
    disp = plot_confusion_matrix(fitted_clf, X, y, normalize=None,
                                 include_values=True, cmap='viridis',
                                 xticks_rotation=45.0)

    def xtick_rotations():
        # Looks up disp.ax_ at call time, after any re-plot.
        return [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]

    assert_allclose(disp.confusion_matrix, expected_cm)
    assert disp.text_.shape == (n_classes, n_classes)
    assert_allclose(xtick_rotations(), 45.0)
    assert_allclose(disp.im_.get_array().data, expected_cm)

    disp.plot(cmap='plasma')
    assert disp.im_.get_cmap().name == 'plasma'

    disp.plot(include_values=False)
    assert disp.text_ is None

    disp.plot(xticks_rotation=90.0)
    assert_allclose(xtick_rotations(), 90.0)

    disp.plot(values_format='e')
    expected_text = np.array(
        [format(value, 'e') for value in expected_cm.ravel(order="C")])
    actual_text = np.array(
        [t.get_text() for t in disp.text_.ravel(order="C")])
    assert_array_equal(expected_text, actual_text)
def test_confusion_matrix_contrast(pyplot):
    """Cell text colour must contrast with its colormapped background.

    ``cm`` is 0.5 on the diagonal and 0 elsewhere, so diagonal and
    off-diagonal cells land at opposite ends of the colormap.
    """
    black = [0.0, 0.0, 0.0, 1.0]
    white = [1.0, 1.0, 1.0, 1.0]

    cm = np.eye(2) / 2
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    # gray runs black -> white: the diagonal (max) gets a light background
    # so its text is black; the off-diagonal (min) is dark so text is white.
    disp.plot(cmap=pyplot.cm.gray)
    assert_allclose(disp.text_[0, 0].get_color(), black)
    assert_allclose(disp.text_[1, 1].get_color(), black)
    assert_allclose(disp.text_[0, 1].get_color(), white)
    assert_allclose(disp.text_[1, 0].get_color(), white)

    # gray_r is reversed, so the backgrounds swap: diagonal text is white
    # and off-diagonal text is black.
    disp.plot(cmap=pyplot.cm.gray_r)
    assert_allclose(disp.text_[0, 0].get_color(), white)
    assert_allclose(disp.text_[1, 1].get_color(), white)
    assert_allclose(disp.text_[0, 1].get_color(), black)
    assert_allclose(disp.text_[1, 0].get_color(), black)
@pytest.mark.parametrize(
    "clf", [LogisticRegression(),
            make_pipeline(StandardScaler(), LogisticRegression()),
            make_pipeline(make_column_transformer((StandardScaler(), [0, 1])),
                          LogisticRegression())])
def test_confusion_matrix_pipeline(pyplot, clf, data, n_classes):
    """``plot_confusion_matrix`` works with bare estimators and pipelines.

    An unfitted estimator must raise ``NotFittedError``; once fitted, the
    plotted matrix must equal ``confusion_matrix`` of its predictions.
    """
    from sklearn.base import clone

    # The parametrized instances are built once, when the decorator runs,
    # and are shared across runs of this test. Work on an unfitted clone so
    # the shared object is never mutated; otherwise a rerun would see an
    # already-fitted estimator and the NotFittedError check would fail.
    clf = clone(clf)
    X, y = data

    with pytest.raises(NotFittedError):
        plot_confusion_matrix(clf, X, y)

    clf.fit(X, y)
    y_pred = clf.predict(X)
    disp = plot_confusion_matrix(clf, X, y)

    cm = confusion_matrix(y, y_pred)
    assert_allclose(disp.confusion_matrix, cm)
    assert disp.text_.shape == (n_classes, n_classes)