import sys
import unittest
from io import StringIO

import joblib
import numpy as np
import scipy.sparse as sp

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import deprecated
from sklearn.utils._testing import (assert_raises_regex,
                                   ignore_warnings,
                                   assert_warns, assert_raises,
                                   SkipTest)
from sklearn.utils.estimator_checks import check_estimator, _NotAnArray
from sklearn.utils.estimator_checks \
    import check_class_weight_balanced_linear_classifier
from sklearn.utils.estimator_checks import set_random_state
from sklearn.utils.estimator_checks import _set_checking_parameters
from sklearn.utils.estimator_checks import check_estimators_unfitted
from sklearn.utils.estimator_checks import check_fit_score_takes_y
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.estimator_checks import check_outlier_corruption
from sklearn.utils.fixes import _parse_version
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression, SGDClassifier
from sklearn.mixture import GaussianMixture
from sklearn.cluster import MiniBatchKMeans
from sklearn.decomposition import NMF
from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_X_y, check_array

class CorrectNotFittedError(ValueError):
    """Exception class to raise if estimator is used before fitting.

    Like NotFittedError, it inherits from ValueError, but not from
    AttributeError. Used for testing only.
    """

class BaseBadClassifier(ClassifierMixin, BaseEstimator):
    """Classifier stub that performs no input validation at all."""

    def fit(self, X, y):
        """Accept any input unchecked and return the estimator itself."""
        return self

    def predict(self, X):
        """Return a vector of ones with one entry per row of ``X``."""
        n_samples = X.shape[0]
        return np.ones(n_samples)

class ChangesDict(BaseEstimator):
    """Estimator whose ``predict`` overwrites its own ``key`` parameter."""

    def __init__(self, key=0):
        self.key = key

    def fit(self, X, y=None):
        """Validate ``X`` and ``y`` with ``check_X_y``; nothing is stored."""
        check_X_y(X, y)
        return self

    def predict(self, X):
        """Validate ``X``, set ``self.key`` to 1000 and return ones."""
        validated = check_array(X)
        # Deliberate state change at predict time.
        self.key = 1000
        return np.ones(validated.shape[0])

class SetsWrongAttribute(BaseEstimator):
    """Estimator whose ``fit`` adds a public attribute without underscore."""

    def __init__(self, acceptable_key=0):
        self.acceptable_key = acceptable_key

    def fit(self, X, y=None):
        """Set ``wrong_attribute`` to 0, then validate ``X`` and ``y``."""
        # The attribute is assigned before validation runs.
        self.wrong_attribute = 0
        check_X_y(X, y)
        return self

class ChangesWrongAttribute(BaseEstimator):
    """Estimator whose ``fit`` overwrites a constructor parameter."""

    def __init__(self, wrong_attribute=0):
        self.wrong_attribute = wrong_attribute

    def fit(self, X, y=None):
        """Reset ``wrong_attribute`` to 1, then validate ``X`` and ``y``."""
        # The parameter is mutated before validation runs.
        self.wrong_attribute = 1
        check_X_y(X, y)
        return self

class ChangesUnderscoreAttribute(BaseEstimator):
    """Estimator whose ``fit`` only adds an underscore-prefixed attribute."""

    def fit(self, X, y=None):
        """Set ``_good_attribute`` to 1, then validate ``X`` and ``y``."""
        self._good_attribute = 1
        check_X_y(X, y)
        return self

class RaisesErrorInSetParams(BaseEstimator):
    """Estimator whose ``set_params`` rejects negative values of ``p``."""

    def __init__(self, p=0):
        self.p = p

    def set_params(self, **kwargs):
        """Set ``p`` here (raising if negative); delegate everything else.

        Raises
        ------
        ValueError
            If the requested value of ``p`` is below zero.
        """
        if 'p' not in kwargs:
            return super().set_params(**kwargs)
        new_p = kwargs.pop('p')
        if new_p < 0:
            raise ValueError("p can't be less than 0")
        self.p = new_p
        return super().set_params(**kwargs)

    def fit(self, X, y=None):
        """Validate ``X`` and ``y``; nothing is learned."""
        check_X_y(X, y)
        return self

class ModifiesValueInsteadOfRaisingError(BaseEstimator):
    """Estimator whose ``set_params`` silently replaces negative ``p`` by 0."""

    def __init__(self, p=0):
        self.p = p

    def set_params(self, **kwargs):
        """Set ``p`` here (negative values become 0); delegate the rest."""
        if 'p' not in kwargs:
            return super().set_params(**kwargs)
        requested = kwargs.pop('p')
        # The stored value differs from the requested one when negative.
        self.p = 0 if requested < 0 else requested
        return super().set_params(**kwargs)

    def fit(self, X, y=None):
        """Validate ``X`` and ``y``; nothing is learned."""
        check_X_y(X, y)
        return self

class ModifiesAnotherValue(BaseEstimator):
    """Estimator whose ``set_params`` may change ``b`` when ``a`` is set."""

    def __init__(self, a=0, b='method1'):
        self.a = a
        self.b = b

    def set_params(self, **kwargs):
        """Store ``a``; ``a=None`` additionally switches ``b`` to 'method2'.

        Remaining keyword arguments are forwarded to the parent
        implementation.
        """
        if 'a' not in kwargs:
            return super().set_params(**kwargs)
        new_a = kwargs.pop('a')
        self.a = new_a
        if new_a is None:
            # Side effect on a parameter the caller did not ask to change.
            self.b = 'method2'
        return super().set_params(**kwargs)

    def fit(self, X, y=None):
        """Validate ``X`` and ``y``; nothing is learned."""
        check_X_y(X, y)
        return self

class NoCheckinPredict(BaseBadClassifier):
    """Classifier that validates in ``fit`` but inherits ``predict`` as is."""

    def fit(self, X, y):
        """Validate ``X`` and ``y`` with ``check_X_y`` and return self."""
        check_X_y(X, y)
        return self

class NoSparseClassifier(BaseBadClassifier):
    """Classifier that lets sparse input through validation, then fails."""

    def fit(self, X, y):
        """Validate the data, then raise ``ValueError`` if ``X`` is sparse."""
        validated_X, _ = check_X_y(X, y, accept_sparse=['csr', 'csc'])
        if sp.issparse(validated_X):
            raise ValueError("Nonsensical Error")
        return self

    def predict(self, X):
        """Validate ``X`` and return one 1.0 per sample."""
        n_samples = check_array(X).shape[0]
        return np.ones(n_samples)

class CorrectNotFittedErrorClassifier(BaseBadClassifier):
    """Classifier that refuses to predict before it has been fitted."""

    def fit(self, X, y):
        """Validate the data and store a ``coef_`` vector of ones."""
        X, y = check_X_y(X, y)
        self.coef_ = np.ones(X.shape[1])
        return self

    def predict(self, X):
        """Return ones for validated ``X``; fail first if not fitted."""
        # Without this guard predict succeeds on an unfitted estimator, so
        # check_estimators_unfitted could never accept this class.
        check_is_fitted(self)
        X = check_array(X)
        return np.ones(X.shape[0])

class NoSampleWeightPandasSeriesType(BaseEstimator):
    """Estimator that rejects ``sample_weight`` given as a pandas Series."""

    def fit(self, X, y, sample_weight=None):
        """Validate the data; raise if ``sample_weight`` is a ``Series``.

        Raises
        ------
        ValueError
            If ``sample_weight`` is an instance of ``pandas.Series``.
        """
        # Convert data
        X, y = check_X_y(X, y,
                         accept_sparse=("csr", "csc"),
                         multi_output=True,
                         y_numeric=True)
        # Function is only called after we verify that pandas is installed
        from pandas import Series
        if isinstance(sample_weight, Series):
            raise ValueError("Estimator does not accept 'sample_weight'"
                             "of type pandas.Series")
        return self

    def predict(self, X):
        """Validate ``X`` and return one 1.0 per sample."""
        X = check_array(X)
        return np.ones(X.shape[0])

class BadBalancedWeightsClassifier(BaseBadClassifier):
    """Classifier that miscomputes weights for ``class_weight='balanced'``."""

    def __init__(self, class_weight=None):
        self.class_weight = class_weight

    def fit(self, X, y):
        """Compute class weights from ``y`` and expose them as ``coef_``."""
        from sklearn.preprocessing import LabelEncoder
        from sklearn.utils import compute_class_weight

        label_encoder = LabelEncoder().fit(y)
        classes = label_encoder.classes_
        # Keywords work on every release; newer scikit-learn versions no
        # longer accept ``classes`` and ``y`` positionally.
        class_weight = compute_class_weight(self.class_weight,
                                            classes=classes, y=y)

        # Intentionally modify the balanced class_weight
        # to simulate a bug and raise an exception
        if self.class_weight == "balanced":
            class_weight += 1.

        # Simply assigning coef_ to the class_weight
        self.coef_ = class_weight
        return self

class BadTransformerWithoutMixin(BaseEstimator):
    """Transformer defining ``fit`` and ``transform`` but no mixin."""

    def fit(self, X, y=None):
        """Validate ``X`` with ``check_array`` and return self."""
        check_array(X)
        return self

    def transform(self, X):
        """Return ``X`` as validated by ``check_array``."""
        return check_array(X)

class NotInvariantPredict(BaseEstimator):
    """Estimator whose predictions depend on how many samples it is given."""

    def fit(self, X, y):
        """Validate the data and return self; nothing is learned."""
        # Convert data
        X, y = check_X_y(X, y,
                         accept_sparse=("csr", "csc"),
                         multi_output=True,
                         y_numeric=True)
        return self

    def predict(self, X):
        """Return ones for multi-row input and zeros for a single row."""
        # return 1 if X has more than one element else return 0
        X = check_array(X)
        if X.shape[0] > 1:
            return np.ones(X.shape[0])
        return np.zeros(X.shape[0])

class LargeSparseNotSupportedClassifier(BaseEstimator):
    """Estimator that rejects sparse input carrying 64-bit indices."""

    def fit(self, X, y):
        """Validate the data, then refuse sparse ``X`` with int64 indices.

        Raises
        ------
        ValueError
            If ``X`` is a COO matrix whose ``row`` or ``col`` is int64, or a
            CSR/CSC matrix whose ``indices`` or ``indptr`` is int64.
        """
        # Large sparse input is let through validation on purpose so that
        # the rejection below is the estimator's own.
        X, y = check_X_y(X, y,
                         accept_sparse=("csr", "csc", "coo"),
                         accept_large_sparse=True,
                         multi_output=True,
                         y_numeric=True)
        if sp.issparse(X):
            if X.getformat() == "coo":
                if X.row.dtype == "int64" or X.col.dtype == "int64":
                    raise ValueError(
                        "Estimator doesn't support 64-bit indices")
            elif X.getformat() in ["csc", "csr"]:
                if X.indices.dtype == "int64" or X.indptr.dtype == "int64":
                    raise ValueError(
                        "Estimator doesn't support 64-bit indices")

        return self

class SparseTransformer(BaseEstimator):
    """Transformer that returns its validated input as a CSR matrix."""

    def fit(self, X, y=None):
        """Remember the shape of the validated training data."""
        validated = check_array(X)
        self.X_shape_ = validated.shape
        return self

    def fit_transform(self, X, y=None):
        """Fit on ``X``, then transform the same ``X``."""
        fitted = self.fit(X, y)
        return fitted.transform(X)

    def transform(self, X):
        """Convert validated ``X`` to CSR after checking its column count.

        Raises
        ------
        ValueError
            If ``X`` has a different number of columns than seen in ``fit``.
        """
        dense = check_array(X)
        n_features = dense.shape[1]
        if n_features != self.X_shape_[1]:
            raise ValueError('Bad number of features')
        return sp.csr_matrix(dense)

class UntaggedBinaryClassifier(DecisionTreeClassifier):
    # Toy decision tree limited to two classes. It carries no tag announcing
    # that limitation, so the estimator checks are expected to fail on it.
    def fit(self, X, y, sample_weight=None):
        # Fit the underlying tree first, then reject multi-class targets.
        super().fit(X, y, sample_weight)
        more_than_two = self.n_classes_ > 2
        if np.all(more_than_two):
            raise ValueError('Only 2 classes are supported')
        return self

class TaggedBinaryClassifier(UntaggedBinaryClassifier):
    # Same two-class toy classifier, this time declaring the restriction
    # through the 'binary_only' estimator tag.
    def _more_tags(self):
        return dict(binary_only=True)

class RequiresPositiveYRegressor(LinearRegression):
    """Linear regression that only accepts strictly positive targets."""

    def fit(self, X, y):
        """Validate the data, reject any ``y <= 0``, then fit as usual.

        Raises
        ------
        ValueError
            If any target value is zero or negative.
        """
        X, y = check_X_y(X, y, multi_output=True)
        has_non_positive = (y <= 0).any()
        if has_non_positive:
            raise ValueError('negative y values not supported!')
        return super().fit(X, y)

    def _more_tags(self):
        """Declare the ``requires_positive_y`` estimator tag."""
        return dict(requires_positive_y=True)

def test_not_an_array_array_function():
    # np.sum on a _NotAnArray must fail with a TypeError, while
    # np.may_share_memory must still give an answer.
    if _parse_version(np.__version__) < (1, 17):
        raise SkipTest("array_function protocol not supported in numpy <1.17")
    wrapped = _NotAnArray(np.ones(10))
    expected_error = "Don't want to call array_function sum!"
    assert_raises_regex(TypeError, expected_error, np.sum, wrapped)
    # always returns True
    assert np.may_share_memory(wrapped, None)

def test_check_fit_score_takes_y_works_on_deprecated_fit():
    # Tests that check_fit_score_takes_y works on a class with
    # a deprecated fit method

    class TestEstimatorWithDeprecatedFitMethod(BaseEstimator):
        @deprecated("Deprecated for the purpose of testing "
                    "check_fit_score_takes_y")
        def fit(self, X, y):
            return self

    check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())

def test_check_estimator():
    # tests that the estimator actually fails on "bad" estimators.
    # not a complete test of all checks, which are very extensive.

    # check that we have a set_params and can clone
    msg = "it does not implement a 'get_params' methods"
    assert_raises_regex(TypeError, msg, check_estimator, object)
    assert_raises_regex(TypeError, msg, check_estimator, object())
    # check that values returned by get_params match set_params
    msg = "get_params result does not match what was passed to set_params"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesValueInsteadOfRaisingError())
    assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesAnotherValue())
    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator())
    # check that fit does input validation
    msg = "ValueError not raised"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier())
    # check that sample_weights in fit accepts pandas.Series type
    try:
        from pandas import Series  # noqa
        msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
               "'sample_weight' parameter is of type pandas.Series")
        assert_raises_regex(
            ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType)
    except ImportError:
        # pandas is optional: without it this part is simply not exercised.
        pass
    # check that predict does input validation (doesn't accept dicts in input)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoCheckinPredict())
    # check that estimator state does not change
    # at transform/predict/predict_proba time
    msg = 'Estimator changes __dict__ during predict'
    assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict)
    # check that `fit` only changes attributes that
    # are private (start with an _ or end with a _).
    msg = ('Estimator ChangesWrongAttribute should not change or mutate  '
           'the parameter wrong_attribute from 0 to 1 during fit.')
    assert_raises_regex(AssertionError, msg,
                        check_estimator, ChangesWrongAttribute)
    check_estimator(ChangesUnderscoreAttribute)
    # check that `fit` doesn't add any public attribute
    msg = (r'Estimator adds public attribute\(s\) during the fit method.'
           ' Estimators are only allowed to add private attributes'
           ' either started with _ or ended'
           ' with _ but wrong_attribute added')
    assert_raises_regex(AssertionError, msg,
                        check_estimator, SetsWrongAttribute)
    # check for invariant method
    name = NotInvariantPredict.__name__
    method = 'predict'
    msg = ("{method} of {name} is not invariant when applied "
           "to a subset.").format(method=method, name=name)
    assert_raises_regex(AssertionError, msg,
                        check_estimator, NotInvariantPredict)
    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
    # the check for sparse input handling prints to the stdout,
    # instead of raising an error, so as not to remove the original traceback.
    # that means we need to jump through some hoops to catch it.
    old_stdout = sys.stdout
    string_buffer = StringIO()
    sys.stdout = string_buffer
    try:
        check_estimator(NoSparseClassifier)
    except Exception:
        # Only the printed message matters here; whatever the checks raise
        # for this estimator is deliberately ignored.
        pass
    finally:
        # Always restore the real stdout, even if the checks blow up.
        sys.stdout = old_stdout
    assert msg in string_buffer.getvalue()

    # Large indices test on bad estimator
    msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to '
           r'support \S{3}_64 matrix, and is not failing gracefully.*')
    assert_raises_regex(AssertionError, msg, check_estimator,
                        LargeSparseNotSupportedClassifier)

    # does error on binary_only untagged estimator
    msg = 'Only 2 classes are supported'
    assert_raises_regex(ValueError, msg, check_estimator,
                        UntaggedBinaryClassifier)

    # non-regression test for estimators transforming to sparse data
    check_estimator(SparseTransformer())

    # doesn't error on actual estimator
    check_estimator(LogisticRegression)
    check_estimator(LogisticRegression(C=0.01))
    check_estimator(MultiTaskElasticNet)

    # doesn't error on binary_only tagged estimator
    check_estimator(TaggedBinaryClassifier)

    # Check regressor with requires_positive_y estimator tag
    msg = 'negative y values not supported!'
    assert_raises_regex(ValueError, msg, check_estimator,
                        RequiresPositiveYRegressor)

def test_check_outlier_corruption():
    # This decision vector must be rejected with an AssertionError ...
    rejected = np.array([0., 1., 1.5, 2.])
    assert_raises(AssertionError, check_outlier_corruption, 1, 2, rejected)
    # ... whereas this one must be accepted without raising.
    accepted = np.array([0., 1., 1., 2.])
    check_outlier_corruption(1, 2, accepted)

def test_check_estimator_transformer_no_mixin():
    # Transformer checks must run even when the estimator does not inherit
    # TransformerMixin; the expected failure mentions fit_transform.
    transformer = BadTransformerWithoutMixin()
    pattern = '.*fit_transform.*'
    assert_raises_regex(AttributeError, pattern, check_estimator, transformer)

def test_check_estimator_clones():
    # check that check_estimator doesn't modify the estimator it receives
    from sklearn.datasets import load_iris
    iris = load_iris()

    for Estimator in [GaussianMixture, LinearRegression,
                      RandomForestClassifier, NMF, SGDClassifier,
                      MiniBatchKMeans]:
        # Exercise both an unfitted and an already fitted instance.
        for fit_first in (False, True):
            with ignore_warnings(category=FutureWarning):
                # when 'est = SGDClassifier()'
                est = Estimator()
                _set_checking_parameters(est)
                set_random_state(est)
                if fit_first:
                    est.fit(iris.data + 10, iris.target)
                # Hash before running the checks; it must be unchanged after.
                old_hash = joblib.hash(est)
                check_estimator(est)
            assert old_hash == joblib.hash(est)

def test_check_estimators_unfitted():
    # An estimator whose predict works before fit must be reported by
    # check_estimators_unfitted ...
    expected = "NotFittedError not raised by predict"
    unguarded = NoSparseClassifier()
    assert_raises_regex(AssertionError, expected, check_estimators_unfitted,
                        "estimator", unguarded)

    # ... while CorrectNotFittedErrorClassifier must be accepted.
    guarded = CorrectNotFittedErrorClassifier()
    check_estimators_unfitted("estimator", guarded)

def test_check_no_attributes_set_in_init():
    # check that __init__ may neither set extra attributes nor forget to
    # store its parameters

    class NonConformantEstimatorPrivateSet(BaseEstimator):
        def __init__(self):
            self.you_should_not_set_this_ = None

    class NonConformantEstimatorNoParamSet(BaseEstimator):
        def __init__(self, you_should_set_this_=None):
            # Deliberately does not store the parameter.
            pass

    assert_raises_regex(AssertionError,
                        "Estimator estimator_name should not set any"
                        " attribute apart from parameters during init."
                        r" Found attributes \['you_should_not_set_this_'\].",
                        check_no_attributes_set_in_init,
                        'estimator_name',
                        NonConformantEstimatorPrivateSet())
    assert_raises_regex(AssertionError,
                        "Estimator estimator_name should store all "
                        "parameters as an attribute during init. "
                        "Did not find attributes "
                        r"\['you_should_set_this_'\].",
                        check_no_attributes_set_in_init,
                        'estimator_name',
                        NonConformantEstimatorNoParamSet())

def test_check_estimator_pairwise():
    # check that check_estimator() works on estimator with _pairwise
    # kernel or metric

    # test precomputed kernel
    est = SVC(kernel='precomputed')
    check_estimator(est)

    # test precomputed metric
    est = KNeighborsRegressor(metric='precomputed')
    check_estimator(est)

def test_check_estimator_required_parameters_skip():
    # check that an estimator class which cannot be instantiated without
    # arguments makes check_estimator raise SkipTest

    class MyEstimator(BaseEstimator):
        _required_parameters = ["special_parameter"]

        def __init__(self, special_parameter):
            self.special_parameter = special_parameter

    assert_raises_regex(SkipTest, r"Can't instantiate estimator MyEstimator "
                                  r"which requires parameters "
                                  r"\['special_parameter'\]",
                                  check_estimator, MyEstimator)

def run_tests_without_pytest():
    """Runs the tests in this file without using pytest.

    Every attribute of the ``__main__`` module whose name starts with
    ``test_`` is wrapped in a ``unittest.FunctionTestCase`` and run with a
    ``unittest.TextTestRunner``.
    """
    main_module = sys.modules['__main__']
    test_functions = [getattr(main_module, name) for name in dir(main_module)
                      if name.startswith('test_')]
    test_cases = [unittest.FunctionTestCase(fn) for fn in test_functions]
    suite = unittest.TestSuite()
    # Without these two calls the suite stays empty and nothing is executed.
    suite.addTests(test_cases)
    runner = unittest.TextTestRunner()
    runner.run(suite)

def test_check_class_weight_balanced_linear_classifier():
    # check that ill-computed balanced weights raises an exception
    assert_raises_regex(AssertionError,
                        "Classifier estimator_name is not computing"
                        " class_weight=balanced properly.",
                        check_class_weight_balanced_linear_classifier,
                        'estimator_name',
                        BadBalancedWeightsClassifier)

if __name__ == '__main__':
    # This module is run as a script to check that we have no dependency on
    # pytest for estimator checks.
    run_tests_without_pytest()