from __future__ import division, print_function, absolute_import

from functools import lru_cache
from math import sqrt, exp, sin, cos

import numpy as np
import pytest
from numpy import finfo, power, nan, isclose
from numpy.testing import (assert_warns, assert_,
                           assert_allclose,
                           assert_equal,
                           assert_array_equal)

from scipy.optimize import zeros, newton, root_scalar
from scipy._lib._util import getargspec_no_self as _getargspec
from scipy._lib._numpy_compat import suppress_warnings

# Import testing parameters
from scipy.optimize._tstutils import get_tests, functions as tstutils_functions, fstrings as tstutils_fstrings

# Double-precision machine epsilon, shared by the default tolerances below.
_FLOAT_EPS = finfo(float).eps

TOL = 4 * _FLOAT_EPS  # tolerance: four times machine epsilon

# A few test functions used frequently:
# A simple quadratic, x^2 - 2x - 1 == (x-1)^2 - 2
def f1(x):
    """Quadratic test function ``x**2 - 2*x - 1``."""
    square = x ** 2
    return square - 2 * x - 1

def f1_1(x):
    """First derivative of ``f1``: ``2*x - 2``."""
    doubled = 2 * x
    return doubled - 2

def f1_2(x):
    """Second derivative of ``f1``: the constant 2.0, shaped like ``x``."""
    # Multiplying by zero keeps array inputs array-shaped.
    zero_like = 0 * x
    return 2.0 + zero_like

def f1_and_p_and_pp(x):
    """Return ``(f1(x), f1_1(x), f1_2(x))``: value, first and second derivative."""
    value = f1(x)
    slope = f1_1(x)
    curvature = f1_2(x)
    return value, slope, curvature

# Simple transcendental function
def f2(x):
    """Transcendental test function ``exp(x) - cos(x)``."""
    growth = exp(x)
    return growth - cos(x)

def f2_1(x):
    """First derivative of ``f2``: ``exp(x) + sin(x)``."""
    growth = exp(x)
    return growth + sin(x)

def f2_2(x):
    """Second derivative of ``f2``: ``exp(x) + cos(x)``."""
    growth = exp(x)
    return growth + cos(x)

# lru cached function
@lru_cache()
def f_lrucached(x):
    """Identity function wrapped in ``functools.lru_cache``.

    Used by ``run_check_lru_cached`` to feed the root finders a cached
    callable (regression check for scipy issue 10846). Its only root is 0.
    """
    return x

class TestBasic(object):

    def run_check_by_name(self, name, smoothness=0, **kwargs):
        a = .5
        b = sqrt(3)
        xtol = 4*np.finfo(float).eps
        rtol = 4*np.finfo(float).eps
        for function, fname in zip(tstutils_functions, tstutils_fstrings):
            if smoothness > 0 and fname in ['f4', 'f5', 'f6']:
            r = root_scalar(function, method=name, bracket=[a, b], x0=a,
                            xtol=xtol, rtol=rtol, **kwargs)
            zero = r.root
            assert_allclose(zero, 1.0, atol=xtol, rtol=rtol,
                            err_msg='method %s, function %s' % (name, fname))

    def run_check(self, method, name):
        """Solve each shared tstutils function on ``[0.5, sqrt(3)]`` with ``method``.

        ``method`` is called as ``method(f, a, b, xtol=..., rtol=...,
        full_output=True)`` and is expected to return ``(root, results)``.
        The check requires ``results.converged`` and a root of 1.0 within
        4*eps.

        Parameters
        ----------
        method : callable
            Bracketing root finder, e.g. ``zeros.bisect``.
        name : str
            Label used in assertion messages.
        """
        a = .5
        b = sqrt(3)
        xtol = 4 * _FLOAT_EPS
        rtol = 4 * _FLOAT_EPS
        for function, fname in zip(tstutils_functions, tstutils_fstrings):
            msg = 'method %s, function %s' % (name, fname)
            # full_output=True makes the solver return (root, results object).
            zero, r = method(function, a, b, xtol=xtol, rtol=rtol,
                             full_output=True)
            assert_(r.converged, msg)
            assert_allclose(zero, 1.0, atol=xtol, rtol=rtol, err_msg=msg)

    def run_check_lru_cached(self, method, name):
        """Run ``method`` on the cached identity ``f_lrucached`` over [-1, 1].

        Regression check for https://github.com/scipy/scipy/issues/10846;
        the expected root is 0.
        """
        lower, upper = -1, 1
        root, _ = method(f_lrucached, lower, upper, full_output=True)
        label = 'method %s, function %s' % (name, 'f_lrucached')
        assert_allclose(root, 0, err_msg=label)

    def _run_one_test(self, tc, method, sig_args_keys=None,
                      sig_kwargs_keys=None, **kwargs):
        method_args = []
        for k in sig_args_keys or []:
            if k not in tc:
                # If a,b not present use x0, x1. Similarly for f and func
                k = {'a': 'x0', 'b': 'x1', 'func': 'f'}.get(k, k)

        method_kwargs = dict(**kwargs)
        method_kwargs.update({'full_output': True, 'disp': False})
        for k in sig_kwargs_keys or []:
            method_kwargs[k] = tc[k]

        root = tc.get('root')
        func_args = tc.get('args', ())

            r, rr = method(*method_args, args=func_args, **method_kwargs)
            return root, rr, tc
        except Exception:
            return root, zeros.RootResults(nan, -1, -1, zeros._EVALUEERR), tc

    def run_tests(self, tests, method, name,
                  xtol=4 * _FLOAT_EPS, rtol=4 * _FLOAT_EPS,
                  known_fail=None, **kwargs):
        r"""Run test-cases using the specified method and the supplied signature.

        Extract the arguments for the method call from the test case
        dictionary using the supplied keys for the method's signature.

        Parameters
        ----------
        tests : iterable of dict
            Test cases. Each is read for ``'ID'``, ``'f'``, ``'args'`` and
            (optionally) ``'root'``, plus the keys named by the solver's
            required arguments.
        method : callable
            Root finder under test, e.g. ``zeros.bisect`` or ``zeros.newton``.
        name : str
            Solver label. 'secant', 'newton' and 'halley' are driven with
            ``tol`` (plus ``fprime``/``fprime2`` for newton/halley). Any other
            name gets ``xtol`` and ``rtol``.
        xtol, rtol : float
            Solver tolerances (the newton family receives only ``xtol``, as
            ``tol``). They also set the tolerance used to compare each
            converged root with the known root.
        known_fail : list of str, optional
            IDs of test cases excluded from the convergence and accuracy checks.
        **kwargs
            Extra keyword arguments forwarded to ``method``.
        """
        # The methods have one of two base signatures:
        # (f, a, b, **kwargs)  # bisect/brentq/...
        # (func, x0, **kwargs)  # newton
        sig = _getargspec(method)  # ArgSpec with args, varargs, varkw, defaults
        nDefaults = len(sig[3])
        nRequired = len(sig[0]) - nDefaults
        sig_args_keys = sig[0][:nRequired]
        sig_kwargs_keys = []
        if name in ['secant', 'newton', 'halley']:
            # All three go through newton(); the derivatives taken from the
            # test case select the variant, and newton() is given ``tol``.
            if name in ['newton', 'halley']:
                sig_kwargs_keys.append('fprime')
                if name in ['halley']:
                    sig_kwargs_keys.append('fprime2')
            kwargs['tol'] = xtol
        else:
            kwargs['xtol'] = xtol
            kwargs['rtol'] = rtol

        results = [list(self._run_one_test(
            tc, method, sig_args_keys=sig_args_keys,
            sig_kwargs_keys=sig_kwargs_keys, **kwargs)) for tc in tests]
        # results= [[true root, full output, tc], ...]

        known_fail = known_fail or []
        notcvgd = [elt for elt in results if not elt[1].converged]
        notcvgd = [elt for elt in notcvgd if elt[-1]['ID'] not in known_fail]
        notcvged_IDS = [elt[-1]['ID'] for elt in notcvgd]
        assert_equal([len(notcvged_IDS), notcvged_IDS], [0, []])

        # The usable xtol and rtol depend on the test: start from the defaults
        # and let the tolerances actually handed to the solver override them.
        tols = {'xtol': 4 * _FLOAT_EPS, 'rtol': 4 * _FLOAT_EPS}
        tols.update(kwargs)
        rtol = tols['rtol']
        atol = tols.get('tol', tols['xtol'])

        cvgd = [elt for elt in results if elt[1].converged]
        approx = [elt[1].root for elt in cvgd]
        correct = [elt[0] for elt in cvgd]
        # Converged results whose root is not close to the known root.
        notclose = [[a] + elt for a, c, elt in zip(approx, correct, cvgd) if
                    not isclose(a, c, rtol=rtol, atol=atol)
                    and elt[-1]['ID'] not in known_fail]
        # Evaluate the function and see if is 0 at the purported root
        fvs = [tc['f'](aroot, *(tc['args'])) for aroot, c, fullout, tc in notclose]
        notclose = [[fv] + elt for fv, elt in zip(fvs, notclose) if fv != 0]
        assert_equal([notclose, len(notclose)], [[], 0])

    def run_collection(self, collection, method, name, smoothness=None,
                       xtol=4 * _FLOAT_EPS, rtol=4 * _FLOAT_EPS,
                       known_fail=None, **kwargs):
        r"""Run a collection of tests using the specified method.

        The name is used to determine some optional arguments.

        ``collection`` and ``smoothness`` are passed to ``get_tests`` to select
        the test cases. ``name``, ``xtol``, ``rtol``, ``known_fail`` and any
        extra ``kwargs`` are forwarded unchanged to ``run_tests``."""
        tests = get_tests(collection, smoothness=smoothness)
        self.run_tests(tests, method, name, xtol=xtol, rtol=rtol,
                       known_fail=known_fail, **kwargs)

    def test_bisect(self):
        """bisect: basic suite, cached callable, then the 'aps' collection."""
        solver, label = zeros.bisect, 'bisect'
        self.run_check(solver, label)
        self.run_check_lru_cached(solver, label)
        self.run_collection('aps', solver, label, smoothness=1)

    def test_ridder(self):
        """ridder: basic suite, cached callable, then the 'aps' collection."""
        solver, label = zeros.ridder, 'ridder'
        self.run_check(solver, label)
        self.run_check_lru_cached(solver, label)
        self.run_collection('aps', solver, label, smoothness=1)

    def test_brentq(self):
        """brentq: basic suite, cached callable, then the 'aps' collection."""
        solver, label = zeros.brentq, 'brentq'
        self.run_check(solver, label)
        self.run_check_lru_cached(solver, label)
        # Brentq/h needs a lower tolerance to be specified
        loose = 1e-14
        self.run_collection('aps', solver, label, smoothness=1,
                            xtol=loose, rtol=loose)

    def test_brenth(self):
        """brenth: basic suite, cached callable, then the 'aps' collection."""
        solver, label = zeros.brenth, 'brenth'
        self.run_check(solver, label)
        self.run_check_lru_cached(solver, label)
        # Like brentq, brenth is run on the collection with a 1e-14 tolerance.
        loose = 1e-14
        self.run_collection('aps', solver, label, smoothness=1,
                            xtol=loose, rtol=loose)

    def test_toms748(self):
        """toms748: basic suite, cached callable, then the 'aps' collection."""
        solver, label = zeros.toms748, 'toms748'
        self.run_check(solver, label)
        self.run_check_lru_cached(solver, label)
        self.run_collection('aps', solver, label, smoothness=1)

    def test_newton_collections(self):
        """Newton's method over the 'aps' and 'complex' collections."""
        known_fail = ['aps.13.00',
                      # the next two fail under Windows Py27
                      'aps.12.05', 'aps.12.17']
        for collection in ('aps', 'complex'):
            self.run_collection(collection, zeros.newton, 'newton',
                                smoothness=2, known_fail=known_fail)

    def test_halley_collections(self):
        """Halley's method over the 'aps' and 'complex' collections."""
        # Expected failures: aps.12.06 through aps.12.18, plus aps.13.00.
        known_fail = ['aps.12.%02d' % idx for idx in range(6, 19)]
        known_fail.append('aps.13.00')
        for collection in ('aps', 'complex'):
            self.run_collection(collection, zeros.newton, 'halley',
                                smoothness=2, known_fail=known_fail)

    def f1(x):
        return x**2 - 2*x - 1  # == (x-1)**2 - 2

    def f1_1(x):
        return 2*x - 2

    def f1_2(x):
        return 2.0 + 0*x

    def f2(x):
        return exp(x) - cos(x)

    def f2_1(x):
        return exp(x) + sin(x)

    def f2_2(x):
        return exp(x) + cos(x)

    def test_newton(self):
        """newton() in secant, two-point secant, Newton and Halley modes."""
        # Each triple is (function, first derivative, second derivative).
        families = [(self.f1, self.f1_1, self.f1_2),
                    (self.f2, self.f2_1, self.f2_2)]
        for func, dfunc, d2func in families:
            variants = [
                {},                                   # secant from x0 alone
                {'x1': 5},                            # secant, x0 and x1
                {'fprime': dfunc},                    # newton
                {'fprime': dfunc, 'fprime2': d2func},  # halley
            ]
            for extra in variants:
                root = zeros.newton(func, 3, tol=1e-6, **extra)
                assert_allclose(func(root), 0, atol=1e-6)

    def test_newton_by_name(self):
        r"""Invoke newton through root_scalar()"""
        for func, dfunc, _ in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
            sol = root_scalar(func, method='newton', x0=3, fprime=dfunc,
                              xtol=1e-6)
            residual = func(sol.root)
            assert_allclose(residual, 0, atol=1e-6)

    def test_secant_by_name(self):
        r"""Invoke secant through root_scalar()"""
        for func, _, _ in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
            # Second starting point below, then above, x0=3.
            for second in (2, 5):
                sol = root_scalar(func, method='secant', x0=3, x1=second,
                                  xtol=1e-6)
                assert_allclose(func(sol.root), 0, atol=1e-6)

    def test_halley_by_name(self):
        r"""Invoke halley through root_scalar()"""
        for func, dfunc, d2func in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
            derivs = {'fprime': dfunc, 'fprime2': d2func}
            sol = root_scalar(func, method='halley', x0=3, xtol=1e-6,
                              **derivs)
            assert_allclose(func(sol.root), 0, atol=1e-6)

    def test_root_scalar_fail(self):
        """root_scalar raises ValueError when a method's inputs are missing."""
        incomplete_calls = [
            {'method': 'secant', 'x0': 3, 'xtol': 1e-6},  # no x1
            {'method': 'newton', 'x0': 3, 'xtol': 1e-6},  # no fprime
            {'method': 'halley', 'fprime': f1_1, 'x0': 3, 'xtol': 1e-6},  # no fprime2
            {'method': 'halley', 'fprime2': f1_2, 'x0': 3, 'xtol': 1e-6},  # no fprime
        ]
        for call_kwargs in incomplete_calls:
            with pytest.raises(ValueError):
                root_scalar(f1, **call_kwargs)

    def test_array_newton(self):
        """test newton with array"""

        def f1(x, *a):
            b = a[0] + x * a[3]
            return a[1] - a[2] * (np.exp(b / a[5]) - 1.0) - b / a[4] - x

        def f1_1(x, *a):
            b = a[3] / a[5]
            return -a[2] * np.exp(a[0] / a[5] + x * b) * b - a[3] / a[4] - 1

        def f1_2(x, *a):
            b = a[3] / a[5]
            return -a[2] * np.exp(a[0] / a[5] + x * b) * b**2

        a0 = np.array([
            5.32725221, 5.48673747, 5.49539973,
            5.36387202, 4.80237316, 1.43764452,
            5.23063958, 5.46094772, 5.50512718,
        a1 = (np.sin(range(10)) + 1.0) * 7.0
        args = (a0, a1, 1e-09, 0.004, 10, 0.27456)
        x0 = [7.0] * 10
        x = zeros.newton(f1, x0, f1_1, args)
        x_expected = (
            6.17264965, 11.7702805, 12.2219954,
            7.11017681, 1.18151293, 0.143707955,
            4.31928228, 10.5419107, 12.7552490,
        assert_allclose(x, x_expected)
        # test halley's
        x = zeros.newton(f1, x0, f1_1, args, fprime2=f1_2)
        assert_allclose(x, x_expected)
        # test secant
        x = zeros.newton(f1, x0, args=args)
        assert_allclose(x, x_expected)

    def test_array_secant_active_zero_der(self):
        """test secant doesn't continue to iterate zero derivatives"""
