"""Multi-layer Perceptron
"""
# Authors: Issam H. Laradji <issam.laradji@gmail.com>
# Andreas Mueller
# Jiyuan Qian
# License: BSD 3 clause
import numpy as np
from abc import ABCMeta, abstractmethod
import warnings
import scipy.optimize
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ..base import is_classifier
from ._base import ACTIVATIONS, DERIVATIVES, LOSS_FUNCTIONS
from ._stochastic_optimizers import SGDOptimizer, AdamOptimizer
from ..model_selection import train_test_split
from ..preprocessing import LabelBinarizer
from ..utils import gen_batches, check_random_state
from ..utils import shuffle
from ..utils import _safe_indexing
from ..utils import check_array, column_or_1d
from ..exceptions import ConvergenceWarning
from ..utils.extmath import safe_sparse_dot
from ..utils.validation import check_is_fitted, _deprecate_positional_args
from ..utils.multiclass import _check_partial_fit_first_call, unique_labels
from ..utils.multiclass import type_of_target
from ..utils.optimize import _check_optimize_result
# Names of the stochastic solvers. When ``self.solver`` is one of these,
# ``_initialize`` additionally creates the loss curve and the
# no-improvement / early-stopping tracking attributes.
_STOCHASTIC_SOLVERS = ['sgd', 'adam']
def _pack(coefs_, intercepts_):
    """Flatten all parameters into one 1-D vector.

    Parameters
    ----------
    coefs_ : list of ndarray
        Weight arrays, one per layer transition.
    intercepts_ : list of ndarray
        Bias arrays, one per layer transition.

    Returns
    -------
    packed : ndarray
        Every weight array (in list order) followed by every bias array
        (in list order), each flattened with ``ravel`` and joined end to end.
    """
    # List concatenation: all coefficient arrays first, then all intercepts.
    parameter_arrays = coefs_ + intercepts_
    flattened = [array.ravel() for array in parameter_arrays]
    return np.hstack(flattened)
class BaseMultilayerPerceptron(BaseEstimator, metaclass=ABCMeta):
"""Base class for MLP classification and regression.
Warning: This class should not be used directly.
Use derived classes instead.
.. versionadded:: 0.18
"""
@abstractmethod
def __init__(self, hidden_layer_sizes, activation, solver,
             alpha, batch_size, learning_rate, learning_rate_init, power_t,
             max_iter, loss, shuffle, random_state, tol, verbose,
             warm_start, momentum, nesterovs_momentum, early_stopping,
             validation_fraction, beta_1, beta_2, epsilon,
             n_iter_no_change, max_fun):
    """Store every hyper-parameter on the instance under its own name.

    No validation or conversion happens here; each argument is saved
    unchanged as the attribute of the same name. The method is declared
    abstract.
    """
    # Attributes are assigned in the same order as the signature.
    self.hidden_layer_sizes = hidden_layer_sizes
    self.activation = activation
    self.solver = solver
    self.alpha = alpha
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.learning_rate_init = learning_rate_init
    self.power_t = power_t
    self.max_iter = max_iter
    self.loss = loss
    self.shuffle = shuffle
    self.random_state = random_state
    self.tol = tol
    self.verbose = verbose
    self.warm_start = warm_start
    self.momentum = momentum
    self.nesterovs_momentum = nesterovs_momentum
    self.early_stopping = early_stopping
    self.validation_fraction = validation_fraction
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.n_iter_no_change = n_iter_no_change
    self.max_fun = max_fun
def _unpack(self, packed_parameters):
    """Write a packed parameter vector back into coefs_ and intercepts_.

    Parameters
    ----------
    packed_parameters : ndarray
        Flat vector holding all weights and biases. The slice boundaries
        (and, for weights, the target shape) of each layer are read from
        ``self._coef_indptr`` and ``self._intercept_indptr``.
    """
    n_transitions = self.n_layers_ - 1
    for layer in range(n_transitions):
        # Weights: take the layer's slice and restore its 2-D shape.
        w_start, w_end, w_shape = self._coef_indptr[layer]
        flat_weights = packed_parameters[w_start:w_end]
        self.coefs_[layer] = np.reshape(flat_weights, w_shape)

        # Biases stay one-dimensional, so the slice is stored as is.
        b_start, b_end = self._intercept_indptr[layer]
        self.intercepts_[layer] = packed_parameters[b_start:b_end]
def _forward_pass(self, activations):
    """Propagate the input through every layer of the network.

    Parameters
    ----------
    activations : list
        Per-layer values. Entry 0 holds the input and is only read;
        entries 1 to ``n_layers_ - 1`` are overwritten with the values of
        the hidden and output layers, so the list needs ``n_layers_``
        entries.

    Returns
    -------
    activations : list
        The same list object that was passed in, with its entries updated.
    """
    hidden_fn = ACTIVATIONS[self.activation]
    n_transitions = self.n_layers_ - 1

    for i in range(n_transitions):
        # Affine step: previous layer values times weights, plus bias.
        layer_values = safe_sparse_dot(activations[i], self.coefs_[i])
        layer_values += self.intercepts_[i]

        # The hidden non-linearity is skipped for the final transition,
        # which receives the output activation after the loop.
        if i + 1 != n_transitions:
            layer_values = hidden_fn(layer_values)
        activations[i + 1] = layer_values

    # ``i`` still holds the index of the final transition here, so
    # ``i + 1`` addresses the output layer.
    output_fn = ACTIVATIONS[self.out_activation_]
    activations[i + 1] = output_fn(activations[i + 1])
    return activations
def _compute_loss_grad(self, layer, n_samples, activations, deltas,
                       coef_grads, intercept_grads):
    """Fill in the weight and bias gradients of a single layer.

    Parameters
    ----------
    layer : int
        Index of the layer transition whose gradients are computed.
    n_samples : int
        Divisor applied to the weight gradient.
    activations : list
        Per-layer values; ``activations[layer]`` is the input of this
        transition.
    deltas : list
        Per-transition error terms; ``deltas[layer]`` must already be set.
    coef_grads, intercept_grads : list
        Gradient containers; entry ``layer`` of each is overwritten.

    Returns
    -------
    coef_grads, intercept_grads : list
        The same two containers that were passed in.
    """
    delta = deltas[layer]

    # The L2 term (alpha * weights) is added before the division, so it is
    # scaled by 1 / n_samples together with the data term.
    weight_grad = safe_sparse_dot(activations[layer].T, delta)
    weight_grad += self.alpha * self.coefs_[layer]
    weight_grad /= n_samples
    coef_grads[layer] = weight_grad

    # Bias gradient: mean of the error terms along axis 0.
    intercept_grads[layer] = np.mean(delta, 0)
    return coef_grads, intercept_grads
def _loss_grad_lbfgs(self, packed_coef_inter, X, y, activations, deltas,
                     coef_grads, intercept_grads):
    """Evaluate loss and gradient for a packed parameter vector.

    The packed vector is first written into ``self.coefs_`` and
    ``self.intercepts_``, then one backpropagation pass is run and the
    resulting gradients are packed into a single vector again.

    Parameters
    ----------
    packed_coef_inter : ndarray
        A vector comprising the flattened coefficients and intercepts.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input data.

    y : ndarray
        The target values.

    activations : list
        Work list of per-layer values, forwarded to ``_backprop``.

    deltas : list
        Work list of per-layer error terms, forwarded to ``_backprop``.

    coef_grads : list
        Work list receiving the gradient of each weight array.

    intercept_grads : list
        Work list receiving the gradient of each bias array.

    Returns
    -------
    loss : float
    grad : ndarray
        Flattened weight gradients followed by flattened bias gradients,
        i.e. one entry per parameter.
    """
    # Make the candidate parameters the current model parameters.
    self._unpack(packed_coef_inter)

    backprop_result = self._backprop(
        X, y, activations, deltas, coef_grads, intercept_grads)
    loss, coef_grads, intercept_grads = backprop_result

    return loss, _pack(coef_grads, intercept_grads)
def _backprop(self, X, y, activations, deltas, coef_grads,
              intercept_grads):
    """Run one forward and one backward pass through the network.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input data. Only its first dimension (the sample count) is
        read directly in this method.

    y : ndarray
        The target values; subtracted element-wise from the output layer
        values.

    activations : list
        Per-layer values, handed to ``_forward_pass`` which fills in the
        hidden and output layers.

    deltas : list
        Per-transition error terms, i.e. gradients of the loss with respect
        to z in each layer, where z = wx + b is the value of a layer before
        its activation function. Overwritten by this method.

    coef_grads : list
        Receives the gradient of each weight array.

    intercept_grads : list
        Receives the gradient of each bias array.

    Returns
    -------
    loss : float
    coef_grads : list
    intercept_grads : list
    """
    n_samples = X.shape[0]

    # Forward pass; the final list entry is the network output.
    activations = self._forward_pass(activations)
    output = activations[-1]

    # A logistic output layer with 'log_loss' uses the binary variant.
    loss_name = self.loss
    if loss_name == 'log_loss' and self.out_activation_ == 'logistic':
        loss_name = 'binary_log_loss'
    loss = LOSS_FUNCTIONS[loss_name](y, output)

    # L2 penalty: sum of squared weights over all layers (biases are not
    # penalised), scaled by alpha / (2 * n_samples).
    squared_norms = [np.dot(w.ravel(), w.ravel()) for w in self.coefs_]
    penalty = np.sum(np.array(squared_norms))
    loss += (0.5 * self.alpha) * penalty / n_samples

    # Output-layer error. The plain difference is valid for the following
    # pairings of output activation and loss: sigmoid with binary cross
    # entropy, softmax with categorical cross entropy, and identity with
    # squared loss.
    output_layer = self.n_layers_ - 2
    deltas[output_layer] = output - y
    coef_grads, intercept_grads = self._compute_loss_grad(
        output_layer, n_samples, activations, deltas, coef_grads,
        intercept_grads)

    # Walk backwards through the hidden layers.
    for layer in range(output_layer, 0, -1):
        previous = layer - 1
        deltas[previous] = safe_sparse_dot(deltas[layer],
                                           self.coefs_[layer].T)
        # The derivative helper is called for its effect on
        # ``deltas[previous]``; its return value is not used.
        apply_derivative = DERIVATIVES[self.activation]
        apply_derivative(activations[layer], deltas[previous])

        coef_grads, intercept_grads = self._compute_loss_grad(
            previous, n_samples, activations, deltas, coef_grads,
            intercept_grads)

    return loss, coef_grads, intercept_grads
def _initialize(self, y, layer_units):
    """Reset counters and allocate fresh parameters for a new fit.

    Parameters
    ----------
    y : ndarray
        Two-dimensional targets; its second dimension gives the number of
        outputs.
    layer_units : list of int
        Number of units in every layer, from input to output.
    """
    # Iteration counters start from scratch.
    self.n_iter_ = 0
    self.t_ = 0
    self.n_outputs_ = y.shape[1]
    self.n_layers_ = len(layer_units)

    # Choose the output activation: softmax for multiclass targets,
    # logistic for the other classifier targets (binary and multi-label),
    # identity for anything that is not a classifier.
    if is_classifier(self):
        if self._label_binarizer.y_type_ == 'multiclass':
            self.out_activation_ = 'softmax'
        else:
            self.out_activation_ = 'logistic'
    else:
        self.out_activation_ = 'identity'

    # One weight array and one bias array per pair of adjacent layers.
    self.coefs_ = []
    self.intercepts_ = []
    for fan_in, fan_out in zip(layer_units[:-1], layer_units[1:]):
        weights, biases = self._init_coef(fan_in, fan_out)
        self.coefs_.append(weights)
        self.intercepts_.append(biases)

    # The remaining state is only needed by the stochastic solvers.
    if self.solver not in _STOCHASTIC_SOLVERS:
        return

    self.loss_curve_ = []
    self._no_improvement_count = 0
    if self.early_stopping:
        self.validation_scores_ = []
        self.best_validation_score_ = -np.inf
    else:
        self.best_loss_ = np.inf
def _init_coef(self, fan_in, fan_out):
    """Draw initial weights and biases for one layer transition.

    Uses the initialization method recommended by Glorot et al.: values
    are sampled uniformly from ``[-bound, bound)`` with
    ``bound = sqrt(factor / (fan_in + fan_out))``, where ``factor`` is 2
    for the 'logistic' activation and 6 otherwise.

    Parameters
    ----------
    fan_in : int
        Number of units feeding into the transition.
    fan_out : int
        Number of units the transition produces.

    Returns
    -------
    coef_init : ndarray of shape (fan_in, fan_out)
    intercept_init : ndarray of shape (fan_out,)
    """
    factor = 2. if self.activation == 'logistic' else 6.
    bound = np.sqrt(factor / (fan_in + fan_out))

    # Weights are drawn before biases; the order determines which random
    # numbers each of them receives.
    rng = self._random_state
    coef_init = rng.uniform(-bound, bound, (fan_in, fan_out))
    intercept_init = rng.uniform(-bound, bound, fan_out)
    return coef_init, intercept_init
def _fit(self, X, y, incremental=False):
# Make sure self.hidden_layer_sizes is a list
hidden_layer_sizes = self.hidden_layer_sizes
if not hasattr(hidden_layer_sizes, "__iter__"):
hidden_layer_sizes = [hidden_layer_sizes]
hidden_layer_sizes = list(hidden_layer_sizes)
# Validate input parameters.
self._validate_hyperparameters()
if np.any(np.array(hidden_layer_sizes) <= 0):
raise ValueError("hidden_layer_sizes must be > 0, got %s." %
hidden_layer_sizes)
X, y = self._validate_input(X, y, incremental)
n_samples, n_features = X.shape
# Ensure y is 2D
if y.ndim == 1:
y = y.reshape((-1, 1))
self.n_outputs_ = y.shape[1]
layer_units = ([n_features] + hidden_layer_sizes +
[self.n_outputs_])
# check random state
self._random_state = check_random_state(self.random_state)
if not hasattr(self, 'coefs_') or (not self.warm_start and not
incremental):
# First time training the model
self._initialize(y, layer_units)
Loading ...