"""
Multiclass and multilabel classification strategies
===================================================
This module implements multiclass learning algorithms:
- one-vs-the-rest / one-vs-all
- one-vs-one
- error correcting output codes
The estimators provided in this module are meta-estimators: they require a base
estimator to be provided in their constructor. For example, it is possible to
use these estimators to turn a binary classifier or a regressor into a
multiclass classifier. It is also possible to use these estimators with
multiclass estimators in the hope that their accuracy or runtime performance
improves.
All classifiers in scikit-learn implement multiclass classification; you
only need to use this module if you want to experiment with custom multiclass
strategies.
The one-vs-the-rest meta-classifier also implements a `predict_proba` method,
so long as such a method is implemented by the base classifier. This method
returns probabilities of class membership in both the single label and
multilabel case. Note that in the multilabel case, probabilities are the
marginal probability that a given sample falls in the given class. As such, in
the multilabel case these probabilities, taken over all possible labels for a
given sample, *will not* sum to unity, as they do in the single label
case.
"""
# Author: Mathieu Blondel <mathieu@mblondel.org>
# Author: Hamzeh Alsalhi <93hamsal@gmail.com>
#
# License: BSD 3 clause
import array
import numpy as np
import warnings
import scipy.sparse as sp
from .base import BaseEstimator, ClassifierMixin, clone, is_classifier
from .base import MetaEstimatorMixin, is_regressor
from .preprocessing import LabelBinarizer
from .metrics.pairwise import euclidean_distances
from .utils import check_random_state
from .utils.validation import _num_samples
from .utils.validation import check_consistent_length
from .utils.validation import check_is_fitted
from .utils import deprecated
from .externals.joblib import Parallel
from .externals.joblib import delayed
__all__ = [
"OneVsRestClassifier",
"OneVsOneClassifier",
"OutputCodeClassifier",
]
def _fit_binary(estimator, X, y, classes=None):
    """Fit one binary sub-problem and return the fitted estimator.

    If ``y`` contains more (or fewer) than one distinct value, a clone of
    ``estimator`` is fitted on ``(X, y)`` and returned.

    If ``y`` contains exactly one distinct value, a ``_ConstantPredictor``
    fitted on that value is returned instead. In that case, when ``classes``
    is given, a warning naming the corresponding entry of ``classes`` is
    issued first.
    """
    distinct = np.unique(y)

    if len(distinct) != 1:
        fitted = clone(estimator)
        fitted.fit(X, y)
        return fitted

    if classes is not None:
        # A constant label of -1 is reported as entry 0 of ``classes``; any
        # other constant value is used directly as the index.
        label_idx = 0 if y[0] == -1 else y[0]
        warnings.warn("Label %s is present in all training examples." %
                      str(classes[label_idx]))
    return _ConstantPredictor().fit(X, distinct)
def _predict_binary(estimator, X):
    """Return the raw output of a single binary estimator for ``X``.

    Regressors contribute their ``predict`` output. Any other estimator
    contributes its flattened ``decision_function`` output, or, when that
    call raises ``AttributeError`` or ``NotImplementedError``, the second
    column of ``predict_proba`` (probability of the positive class).
    """
    if is_regressor(estimator):
        return estimator.predict(X)

    try:
        return np.ravel(estimator.decision_function(X))
    except (AttributeError, NotImplementedError):
        # No usable decision_function: use the probabilities of the
        # positive class instead.
        return estimator.predict_proba(X)[:, 1]
def _check_estimator(estimator):
    """Ensure the estimator exposes at least one scoring method.

    Raises
    ------
    ValueError
        If ``estimator`` has neither a ``decision_function`` nor a
        ``predict_proba`` attribute.
    """
    scoring_methods = ("decision_function", "predict_proba")
    if not any(hasattr(estimator, name) for name in scoring_methods):
        raise ValueError("The base estimator should implement "
                         "decision_function or predict_proba!")
# NOTE: the trailing space after "0.18." matters -- the two literals are
# concatenated into a single user-facing deprecation message.
@deprecated("fit_ovr is deprecated and will be removed in 0.18. "
            "Use the OneVsRestClassifier instead.")
def fit_ovr(estimator, X, y, n_jobs=1):
    """Fit a one-vs-the-rest strategy.

    Thin wrapper that fits a ``OneVsRestClassifier`` and returns its fitted
    parts.

    Parameters
    ----------
    estimator : estimator object
        An estimator object implementing `fit` and one of `decision_function`
        or `predict_proba`.

    X : (sparse) array-like, shape = [n_samples, n_features]
        Data.

    y : (sparse) array-like, shape = [n_samples] or [n_samples, n_classes]
        Multi-class targets. An indicator matrix turns on multilabel
        classification.

    n_jobs : int, optional, default: 1
        Forwarded unchanged to ``OneVsRestClassifier``.

    Returns
    -------
    estimators : list of estimators object
        The list of fitted estimator.

    lb : fitted LabelBinarizer
    """
    ovr = OneVsRestClassifier(estimator, n_jobs=n_jobs).fit(X, y)
    return ovr.estimators_, ovr.label_binarizer_
# NOTE: the trailing space after "0.18." matters -- the two literals are
# concatenated into a single user-facing deprecation message.
@deprecated("predict_ovr is deprecated and will be removed in 0.18. "
            "Use the OneVsRestClassifier instead.")
def predict_ovr(estimators, label_binarizer, X):
    """Predict multi-class targets using the one vs rest strategy.

    Parameters
    ----------
    estimators : list of `n_classes` estimators
        Estimators used for predictions. The list must be homogeneous with
        respect to the type of estimators. fit_ovr supplies this list as
        part of its output.

    label_binarizer : LabelBinarizer object
        Object used to transform multiclass labels to binary labels and
        vice-versa. fit_ovr supplies this object as part of its output.

    X : (sparse) array-like, shape = [n_samples, n_features]
        Data.

    Returns
    -------
    y : (sparse) array-like, shape = [n_samples] or [n_samples, n_classes].
        Predicted multi-class targets.

    Raises
    ------
    ValueError
        If ``estimators`` contains more than one estimator type
        (``_ConstantPredictor`` instances are ignored for this check).
    """
    # Constant predictors stand in for degenerate sub-problems and may be
    # mixed freely with the real estimator type.
    e_types = set(type(e) for e in estimators
                  if not isinstance(e, _ConstantPredictor))
    if len(e_types) > 1:
        raise ValueError("List of estimators must contain estimators of the"
                         " same type but contains types {0}".format(e_types))

    # Rebuild a classifier around the already-fitted parts and delegate.
    ovr = OneVsRestClassifier(clone(estimators[0]))
    ovr.estimators_ = estimators
    ovr.label_binarizer_ = label_binarizer
    return ovr.predict(X)
# NOTE: the trailing space after "0.18." matters -- the two literals are
# concatenated into a single user-facing deprecation message.
@deprecated("predict_proba_ovr is deprecated and will be removed in 0.18. "
            "Use the OneVsRestClassifier instead.")
def predict_proba_ovr(estimators, X, is_multilabel):
    """Compute probability estimates from a list of fitted binary estimators.

    Parameters
    ----------
    estimators : list of estimators
        Fitted estimators implementing `predict_proba`. The list must be
        homogeneous with respect to the type of estimators
        (``_ConstantPredictor`` instances are ignored for this check).

    X : (sparse) array-like, shape = [n_samples, n_features]
        Data, passed unchanged to each estimator's `predict_proba`.

    is_multilabel : boolean
        If false, every row of the result is divided by its sum so that
        the rows sum to 1. If true, the per-estimator probabilities are
        returned as they are.

    Returns
    -------
    Y : array
        One column per estimator, holding the second column of that
        estimator's `predict_proba(X)` output.

    Raises
    ------
    ValueError
        If ``estimators`` contains more than one estimator type.
    """
    e_types = set(type(e) for e in estimators
                  if not isinstance(e, _ConstantPredictor))
    if len(e_types) > 1:
        raise ValueError("List of estimators must contain estimators of the"
                         " same type but contains types {0}".format(e_types))

    # Column 1 of predict_proba is the probability of the positive class.
    Y = np.array([e.predict_proba(X)[:, 1] for e in estimators]).T

    if not is_multilabel:
        # Then, probabilities should be normalized to 1.
        Y /= np.sum(Y, axis=1)[:, np.newaxis]
    return Y
class _ConstantPredictor(BaseEstimator):
    """Estimator that always outputs the value it was fitted with.

    ``_fit_binary`` uses it in place of the real estimator when a binary
    sub-problem has a single distinct target value.

    The number of samples is read with ``np.shape(X)[0]`` rather than
    ``X.shape[0]`` so that plain sequences (e.g. a list of lists) are
    accepted in addition to arrays and sparse matrices; for any input that
    has a ``shape`` attribute the result is the same.
    """

    def fit(self, X, y):
        """Remember ``y`` as the constant to output; ``X`` is not used."""
        self.y_ = y
        return self

    def predict(self, X):
        """Return the stored value repeated once per sample of ``X``."""
        check_is_fitted(self, 'y_')
        return np.repeat(self.y_, np.shape(X)[0])

    def decision_function(self, X):
        """Return the stored value repeated once per sample of ``X``."""
        check_is_fitted(self, 'y_')
        return np.repeat(self.y_, np.shape(X)[0])

    def predict_proba(self, X):
        """Return one ``[1 - y_, y_]`` row per sample of ``X``."""
        check_is_fitted(self, 'y_')
        return np.repeat([np.hstack([1 - self.y_, self.y_])],
                         np.shape(X)[0], axis=0)
class OneVsRestClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
"""One-vs-the-rest (OvR) multiclass/multilabel strategy
Also known as one-vs-all, this strategy consists in fitting one classifier
per class. For each classifier, the class is fitted against all the other
classes. In addition to its computational efficiency (only `n_classes`
classifiers are needed), one advantage of this approach is its
interpretability. Since each class is represented by one and one classifier
only, it is possible to gain knowledge about the class by inspecting its
corresponding classifier. This is the most commonly used strategy for
multiclass classification and is a fair default choice.
This strategy can also be used for multilabel learning, where a classifier
is used to predict multiple labels for instance, by fitting on a 2-d matrix
in which cell [i, j] is 1 if sample i has label j and 0 otherwise.
In the multilabel learning literature, OvR is also known as the binary
relevance method.
Read more in the :ref:`User Guide <ovr_classification>`.
Parameters
----------
estimator : estimator object
An estimator object implementing `fit` and one of `decision_function`
or `predict_proba`.
n_jobs : int, optional, default: 1
The number of jobs to use for the computation. If -1 all CPUs are used.
If 1 is given, no parallel computing code is used at all, which is
useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are
used. Thus for n_jobs = -2, all CPUs but one are used.
Attributes
----------
estimators_ : list of `n_classes` estimators
Estimators used for predictions.
classes_ : array, shape = [`n_classes`]
Class labels.
label_binarizer_ : LabelBinarizer object
Object used to transform multiclass labels to binary labels and
vice-versa.
multilabel_ : boolean
Whether a OneVsRestClassifier is a multilabel classifier.
"""
def __init__(self, estimator, n_jobs=1):
    """Store the base estimator and the number of parallel jobs.

    Both arguments are kept on the instance exactly as given; nothing is
    validated or fitted here.
    """
    self.estimator = estimator
    self.n_jobs = n_jobs
def fit(self, X, y):
    """Fit underlying estimators.

    One binary estimator is fitted per column of the binarized targets.

    Parameters
    ----------
    X : (sparse) array-like, shape = [n_samples, n_features]
        Data.

    y : (sparse) array-like, shape = [n_samples] or [n_samples, n_classes]
        Multi-class targets. An indicator matrix turns on multilabel
        classification.

    Returns
    -------
    self
    """
    # A sparse LabelBinarizer (sparse_output=True) has been shown to
    # outperform or match a dense one in all cases, with equal or lower
    # overall memory consumption during fitting.
    self.label_binarizer_ = LabelBinarizer(sparse_output=True)
    binarizer = self.label_binarizer_
    Y = binarizer.fit_transform(y).tocsc()

    # One lazily-built job per label column; each column is densified only
    # when its job is generated.
    fit_one = delayed(_fit_binary)
    jobs = (
        fit_one(self.estimator, X, column.toarray().ravel(),
                classes=["not %s" % binarizer.classes_[idx],
                         binarizer.classes_[idx]])
        for idx, column in enumerate(Y.T))

    # When individual estimators are very fast to train, setting
    # n_jobs > 1 can result in slower performance due to the overhead of
    # spawning threads. See joblib issue #112.
    self.estimators_ = Parallel(n_jobs=self.n_jobs)(jobs)
    return self
def predict(self, X):
    """Predict multi-class targets using underlying estimators.

    Parameters
    ----------
    X : (sparse) array-like, shape = [n_samples, n_features]
        Data.

    Returns
    -------
    y : (sparse) array-like, shape = [n_samples] or [n_samples, n_classes].
        Predicted multi-class targets.
    """
    check_is_fitted(self, 'estimators_')

    # Classifiers exposing decision_function are thresholded at 0; every
    # other estimator is thresholded at 0.5.
    first = self.estimators_[0]
    thresh = 0 if (hasattr(first, "decision_function") and
                   is_classifier(first)) else .5

    n_samples = _num_samples(X)
    binarizer = self.label_binarizer_

    if binarizer.y_type_ != "multiclass":
        # Build a CSC indicator matrix column by column: for each estimator,
        # record the rows whose score exceeds the threshold.
        row_ids = array.array('i')
        col_ptr = array.array('i', [0])
        for est in self.estimators_:
            row_ids.extend(np.where(_predict_binary(est, X) > thresh)[0])
            col_ptr.append(len(row_ids))
        ones = np.ones(len(row_ids), dtype=int)
        indicator = sp.csc_matrix((ones, row_ids, col_ptr),
                                  shape=(n_samples, len(self.estimators_)))
        return binarizer.inverse_transform(indicator)

    # Multiclass: keep a running maximum of the scores and the index of the
    # estimator that attains it for each sample.
    best_score = np.empty(n_samples, dtype=float)
    best_score.fill(-np.inf)
    best_index = np.zeros(n_samples, dtype=int)
    for idx, est in enumerate(self.estimators_):
        scores = _predict_binary(est, X)
        np.maximum(best_score, scores, out=best_score)
        best_index[best_score == scores] = idx
    return binarizer.classes_[np.array(best_index.T)]
def predict_proba(self, X):
"""Probability estimates.
The returned estimates for all classes are ordered by label of classes.
Note that in the multilabel case, each sample can have any number of
labels. This returns the marginal probability that the given sample has
the label in question. For example, it is entirely consistent that two
labels both have a 90% probability of applying to a given sample.
In the single label multiclass case, the rows of the returned matrix
sum to 1.
Parameters
Loading ...