"""
Testing for Multi-layer Perceptron module (sklearn.neural_network)
"""
# Author: Issam H. Laradji
# License: BSD 3 clause
import pytest
import sys
import warnings
import re
import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal
from sklearn.datasets import load_digits, load_boston, load_iris
from sklearn.datasets import make_regression, make_multilabel_classification
from sklearn.exceptions import ConvergenceWarning
from io import StringIO
from sklearn.metrics import roc_auc_score
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from scipy.sparse import csr_matrix
from sklearn.utils._testing import ignore_warnings
# Names of the hidden-layer activation functions that the tests below
# iterate over.
ACTIVATION_TYPES = ["identity", "logistic", "tanh", "relu"]
# Multi-class classification fixture: the first 200 samples of the 3-class
# digits subset. The MinMaxScaler is fitted on those 200 rows only.
X_digits, y_digits = load_digits(n_class=3, return_X_y=True)
X_digits_multi = MinMaxScaler().fit_transform(X_digits[:200])
y_digits_multi = y_digits[:200]
# Binary classification fixture: same recipe with the 2-class digits subset.
# Note that X_digits / y_digits are rebound here, so after this point they
# refer to the 2-class data.
X_digits, y_digits = load_digits(n_class=2, return_X_y=True)
X_digits_binary = MinMaxScaler().fit_transform(X_digits[:200])
y_digits_binary = y_digits[:200]
# (X, y) pairs fed to the parametrised classification tests.
classification_datasets = [(X_digits_multi, y_digits_multi),
(X_digits_binary, y_digits_binary)]
# Regression fixture: the StandardScaler is fitted on the full boston data
# and only afterwards are the first 200 rows kept.
boston = load_boston()
Xboston = StandardScaler().fit_transform(boston.data)[: 200]
yboston = boston.target[:200]
# (X, y) pairs fed to the parametrised regression tests.
regression_datasets = [(Xboston, yboston)]
# Unscaled iris data and targets.
iris = load_iris()
X_iris = iris.data
y_iris = iris.target
def test_alpha():
    """Check that a larger ``alpha`` yields weights closer to zero.

    One classifier is trained per ``alpha`` value on the first 100 binary
    digits samples; the summed absolute weights of both layers must shrink
    strictly from each ``alpha`` to the next.
    """
    X_train, y_train = X_digits_binary[:100], y_digits_binary[:100]

    def l1_norm(weights):
        # Sum of absolute values of every entry in a weight matrix.
        return np.sum(np.abs(weights))

    norms_per_alpha = []
    for penalty in np.arange(2):
        clf = MLPClassifier(hidden_layer_sizes=10, alpha=penalty,
                            random_state=1)
        with ignore_warnings(category=ConvergenceWarning):
            clf.fit(X_train, y_train)
        first_layer, second_layer = clf.coefs_[0], clf.coefs_[1]
        norms_per_alpha.append(
            np.array([l1_norm(first_layer), l1_norm(second_layer)]))

    # Compare each alpha with its successor, layer by layer.
    for weaker, stronger in zip(norms_per_alpha, norms_per_alpha[1:]):
        assert (weaker > stronger).all()
def test_fit():
    """Check one SGD ``partial_fit`` step against a hand-worked example.

    A 3-2-1 logistic network is given fixed weights and intercepts, updated
    once on a single sample with ``partial_fit``, and the resulting
    parameters and predicted probability are compared (to 3 decimals) with
    the values derived by hand in the comments below.
    """
    X = np.array([[0.6, 0.8, 0.7]])
    y = np.array([0])
    mlp = MLPClassifier(solver='sgd', learning_rate_init=0.1, alpha=0.1,
                        activation='logistic', random_state=1, max_iter=1,
                        hidden_layer_sizes=2, momentum=0)
    # set weights
    mlp.coefs_ = [0] * 2
    mlp.intercepts_ = [0] * 2
    mlp.n_outputs_ = 1
    mlp.coefs_[0] = np.array([[0.1, 0.2], [0.3, 0.1], [0.5, 0]])
    mlp.coefs_[1] = np.array([[0.1], [0.2]])
    mlp.intercepts_[0] = np.array([0.1, 0.1])
    mlp.intercepts_[1] = np.array([1.0])
    # Initialize parameters
    mlp.n_iter_ = 0
    mlp.learning_rate_ = 0.1
    # Compute the number of layers
    mlp.n_layers_ = 3
    # Pre-allocate gradient matrices
    mlp._coef_grads = [0] * (mlp.n_layers_ - 1)
    mlp._intercept_grads = [0] * (mlp.n_layers_ - 1)
    mlp.out_activation_ = 'logistic'
    mlp.t_ = 0
    mlp.best_loss_ = np.inf
    mlp.loss_curve_ = []
    mlp._no_improvement_count = 0
    mlp._intercept_velocity = [np.zeros_like(intercepts) for
                               intercepts in
                               mlp.intercepts_]
    mlp._coef_velocity = [np.zeros_like(coefs) for coefs in
                          mlp.coefs_]
    mlp.partial_fit(X, y, classes=[0, 1])
    # Manually worked out example (g is the logistic function,
    # X = [X1, X2, X3] = [0.6, 0.8, 0.7], alpha = 0.1, eta = 0.1)
    # h1 = g(X . W1[:, 0] + b11) = g(0.6 * 0.1 + 0.8 * 0.3 + 0.7 * 0.5 + 0.1)
    #    = 0.679178699175393
    # h2 = g(X . W1[:, 1] + b12) = g(0.6 * 0.2 + 0.8 * 0.1 + 0.7 * 0 + 0.1)
    #    = 0.574442516811659
    # o1 = g(h . W2 + b21) = g(0.679 * 0.1 + 0.574 * 0.2 + 1)
    #    = 0.765 (rounded)
    # d21 = -(0 - 0.765) = 0.765
    # d11 = (1 - 0.679) * 0.679 * 0.765 * 0.1 = 0.01667
    # d12 = (1 - 0.574) * 0.574 * 0.765 * 0.2 = 0.0374
    # W1grad11 = X1 * d11 + alpha * W11 = 0.6 * 0.01667 + 0.1 * 0.1 = 0.0200
    # W1grad12 = X1 * d12 + alpha * W12 = 0.6 * 0.0374 + 0.1 * 0.2 = 0.04244
    # W1grad21 = X2 * d11 + alpha * W13 = 0.8 * 0.01667 + 0.1 * 0.3 = 0.043336
    # W1grad22 = X2 * d12 + alpha * W14 = 0.8 * 0.0374 + 0.1 * 0.1 = 0.03992
    # W1grad31 = X3 * d11 + alpha * W15 = 0.7 * 0.01667 + 0.1 * 0.5 = 0.061669
    # W1grad32 = X3 * d12 + alpha * W16 = 0.7 * 0.0374 + 0.1 * 0 = 0.02618
    # W2grad1 = h1 * d21 + alpha * W21 = 0.679 * 0.765 + 0.1 * 0.1 = 0.5294
    # W2grad2 = h2 * d21 + alpha * W22 = 0.574 * 0.765 + 0.1 * 0.2 = 0.45911
    # b1grad1 = d11 = 0.01667
    # b1grad2 = d12 = 0.0374
    # b2grad = d21 = 0.765
    # W1 = W1 - eta * [W1grad11, .., W1grad32] = [[0.1, 0.2], [0.3, 0.1],
    #          [0.5, 0]] - 0.1 * [[0.0200, 0.04244], [0.043336, 0.03992],
    #          [0.061669, 0.02618]] = [[0.098, 0.195756], [0.2956664,
    #          0.096008], [0.4938331, -0.002618]]
    # W2 = W2 - eta * [W2grad1, W2grad2] = [[0.1], [0.2]] - 0.1 *
    #        [[0.5294], [0.45911]] = [[0.04706], [0.154089]]
    # b1 = b1 - eta * [b1grad1, b1grad2] = 0.1 - 0.1 * [0.01667, 0.0374]
    #         = [0.098333, 0.09626]
    # b2 = b2 - eta * b2grad = 1.0 - 0.1 * 0.765 = 0.9235
    assert_almost_equal(mlp.coefs_[0], np.array([[0.098, 0.195756],
                                                 [0.2956664, 0.096008],
                                                 [0.4938331, -0.002618]]),
                        decimal=3)
    assert_almost_equal(mlp.coefs_[1], np.array([[0.04706], [0.154089]]),
                        decimal=3)
    assert_almost_equal(mlp.intercepts_[0],
                        np.array([0.098333, 0.09626]), decimal=3)
    assert_almost_equal(mlp.intercepts_[1], np.array(0.9235), decimal=3)
    # Testing output
    #  h1 = g(X . W1[:, 0] + b11) = g(0.6 * 0.098 + 0.8 * 0.2956664 +
    #               0.7 * 0.4938331 + 0.098333) = 0.677
    #  h2 = g(X . W1[:, 1] + b12) = g(0.6 * 0.195756 + 0.8 * 0.096008 +
    #            0.7 * -0.002618 + 0.09626) = 0.572
    #  o1 = h . W2 + b21 = 0.677 * 0.04706 +
    #             0.572 * 0.154089 + 0.9235 = 1.043
    #  prob = sigmoid(o1) = 0.739
    assert_almost_equal(mlp.predict_proba(X)[0, 1], 0.739, decimal=3)
def test_gradient():
    """Compare the analytical gradient with a numerical estimate.

    For every activation type and for 2 and 3 labels, the gradient returned
    by ``_loss_grad_lbfgs`` must agree with a central finite-difference
    approximation of the loss. This guards the activation functions and
    their derivatives.
    """
    n_samples, n_features = 5, 10
    step = 1e-5
    for n_labels in (2, 3):
        rng = np.random.RandomState(seed=42)
        X = rng.rand(n_samples, n_features)
        y = 1 + np.mod(np.arange(n_samples) + 1, n_labels)
        Y = LabelBinarizer().fit_transform(y)
        for activation in ACTIVATION_TYPES:
            mlp = MLPClassifier(activation=activation, hidden_layer_sizes=10,
                                solver='lbfgs', alpha=1e-5,
                                learning_rate_init=0.2, max_iter=1,
                                random_state=1)
            mlp.fit(X, y)

            # Flatten all weights followed by all intercepts into one vector.
            packed_params = mlp.coefs_ + mlp.intercepts_
            theta = np.hstack([param.ravel() for param in packed_params])

            layer_units = [X.shape[1], mlp.hidden_layer_sizes, mlp.n_outputs_]

            # Work buffers handed to the loss/gradient routine.
            activations = [X]
            deltas = []
            coef_grads = []
            intercept_grads = []
            for layer in range(mlp.n_layers_ - 1):
                fan_in, fan_out = layer_units[layer], layer_units[layer + 1]
                activations.append(np.empty((X.shape[0], fan_out)))
                deltas.append(np.empty((X.shape[0], fan_out)))
                coef_grads.append(np.empty((fan_in, fan_out)))
                intercept_grads.append(np.empty(fan_out))

            def loss_and_grad(packed):
                return mlp._loss_grad_lbfgs(packed, X, Y, activations, deltas,
                                            coef_grads, intercept_grads)

            # analytical gradient
            _, analytic_grad = loss_and_grad(theta)

            # numerical gradient via central differences, one coordinate
            # at a time
            n_params = theta.shape[0]
            numeric_grad = np.zeros(theta.size)
            for k in range(n_params):
                offset = np.zeros(n_params)
                offset[k] = step
                loss_plus = loss_and_grad(theta + offset)[0]
                loss_minus = loss_and_grad(theta - offset)[0]
                numeric_grad[k] = (loss_plus - loss_minus) / (step * 2.0)

            assert_almost_equal(numeric_grad, analytic_grad)
@pytest.mark.parametrize('X,y', classification_datasets)
def test_lbfgs_classification(X, y):
    """Check the lbfgs solver on the binary and multi-class digits data.

    For every activation the training accuracy must exceed 0.95, and the
    predictions on the held-out rows must have the expected length and the
    same dtype kind as the training labels.
    """
    n_train = 150
    X_train, X_test = X[:n_train], X[n_train:]
    y_train = y[:n_train]
    expected_shape_dtype = (X_test.shape[0], y_train.dtype.kind)

    for activation in ACTIVATION_TYPES:
        clf = MLPClassifier(solver='lbfgs', hidden_layer_sizes=50,
                            max_iter=150, shuffle=True, random_state=1,
                            activation=activation)
        clf.fit(X_train, y_train)
        predicted = clf.predict(X_test)
        assert clf.score(X_train, y_train) > 0.95
        observed_shape_dtype = (predicted.shape[0], predicted.dtype.kind)
        assert observed_shape_dtype == expected_shape_dtype
@pytest.mark.parametrize('X,y', regression_datasets)
def test_lbfgs_regression(X, y):
    """Check the lbfgs solver on the boston regression problem.

    The training score must exceed 0.84 for the identity activation and
    0.95 for every other activation.
    """
    for activation in ACTIVATION_TYPES:
        reg = MLPRegressor(solver='lbfgs', hidden_layer_sizes=50,
                           max_iter=150, shuffle=True, random_state=1,
                           activation=activation)
        reg.fit(X, y)
        # Non linear models perform much better than the linear bottleneck,
        # so only 'identity' gets the lower bar.
        min_score = 0.84 if activation == 'identity' else 0.95
        assert reg.score(X, y) > min_score
@pytest.mark.parametrize('X,y', classification_datasets)
def test_lbfgs_classification_maxfun(X, y):
    """Check that ``max_fun`` caps lbfgs independently of ``max_iter``.

    With ``max_fun`` far below ``max_iter``, fitting must emit a
    ConvergenceWarning and ``n_iter_`` must not exceed ``max_fun``.
    """
    max_fun = 10
    for activation in ACTIVATION_TYPES:
        clf = MLPClassifier(solver='lbfgs', hidden_layer_sizes=50,
                            max_iter=150, max_fun=max_fun, shuffle=True,
                            random_state=1, activation=activation)
        with pytest.warns(ConvergenceWarning):
            clf.fit(X, y)
            assert clf.n_iter_ <= max_fun
@pytest.mark.parametrize('X,y', regression_datasets)
def test_lbfgs_regression_maxfun(X, y):
    """Check that ``max_fun`` caps lbfgs independently of ``max_iter``.

    With ``max_fun`` far below ``max_iter``, fitting must emit a
    ConvergenceWarning and ``n_iter_`` must not exceed ``max_fun``. A
    negative ``max_fun`` must make ``fit`` raise ValueError.
    """
    max_fun = 10
    for activation in ACTIVATION_TYPES:
        reg = MLPRegressor(solver='lbfgs', hidden_layer_sizes=50,
                           max_iter=150, max_fun=max_fun, shuffle=True,
                           random_state=1, activation=activation)
        with pytest.warns(ConvergenceWarning):
            reg.fit(X, y)
            assert reg.n_iter_ <= max_fun

    # Reuse the estimator left over from the last loop iteration to check
    # that an invalid max_fun is rejected.
    reg.max_fun = -1
    with pytest.raises(ValueError):
        reg.fit(X, y)
def test_learning_rate_warmstart():
    """Check that ``warm_start`` reuses the state of the previous fit.

    The model is fitted twice with ``max_iter=1``. With a constant schedule
    the optimizer's learning rate must be unchanged; with inverse scaling
    it must equal ``learning_rate_init / (8 + 1) ** power_t`` after the
    second fit.
    """
    X = [[3, 2], [1, 6], [5, 6], [-2, -4]]
    y = [1, 1, 1, 0]
    for schedule in ("invscaling", "constant"):
        clf = MLPClassifier(solver='sgd', hidden_layer_sizes=4,
                            learning_rate=schedule, max_iter=1,
                            power_t=0.25, warm_start=True)
        with ignore_warnings(category=ConvergenceWarning):
            clf.fit(X, y)
            prev_eta = clf._optimizer.learning_rate
            clf.fit(X, y)
            post_eta = clf._optimizer.learning_rate

        if schedule == 'invscaling':
            # NOTE(review): 8 is presumably 4 samples x 2 one-epoch fits,
            # i.e. the number of samples seen so far -- confirm.
            expected_eta = clf.learning_rate_init / pow(8 + 1, clf.power_t)
            assert expected_eta == post_eta
        elif schedule == 'constant':
            assert prev_eta == post_eta
def test_multilabel_classification():
    """Check that multi-label classification works with fit and partial_fit.

    Covers three paths: a full lbfgs ``fit``, 100 incremental sgd
    ``partial_fit`` calls, and a default model with ``early_stopping=True``
    which only has to fit and predict without raising.
    """
    X, y = make_multilabel_classification(n_samples=50, random_state=0,
                                          return_indicator=True)

    # fit method
    lbfgs_clf = MLPClassifier(solver='lbfgs', hidden_layer_sizes=50,
                              alpha=1e-5, max_iter=150, random_state=0,
                              activation='logistic', learning_rate_init=0.2)
    lbfgs_clf.fit(X, y)
    assert lbfgs_clf.score(X, y) > 0.97

    # partial fit method
    sgd_clf = MLPClassifier(solver='sgd', hidden_layer_sizes=50,
                            max_iter=150, random_state=0,
                            activation='logistic', alpha=1e-5,
                            learning_rate_init=0.2)
    for _ in range(100):
        sgd_clf.partial_fit(X, y, classes=[0, 1, 2, 3, 4])
    assert sgd_clf.score(X, y) > 0.9

    # Make sure early stopping still works now that splitting is stratified
    # by default (it is disabled for multilabel classification).
    early_clf = MLPClassifier(early_stopping=True)
    early_clf.fit(X, y)
    early_clf.predict(X)
Loading ...