"""
Tests for discrete models
Notes
-----
DECIMAL_3 is used for the Poisson tests because there appears to be a loss
of precision in the Stata *.dta -> *.csv conversion, NOT in the estimator.
"""
# pylint: disable-msg=E1101
from statsmodels.compat.pandas import assert_index_equal
import os
import warnings
import numpy as np
from numpy.testing import (assert_, assert_raises, assert_almost_equal,
assert_equal, assert_array_equal, assert_allclose,
assert_array_less)
import pandas as pd
import pytest
from scipy import stats
from statsmodels.discrete.discrete_model import (Logit, Probit, MNLogit,
Poisson, NegativeBinomial,
CountModel,
GeneralizedPoisson,
NegativeBinomialP)
from statsmodels.discrete.discrete_margins import _iscount, _isdummy
import statsmodels.api as sm
import statsmodels.formula.api as smf
from .results.results_discrete import Spector, DiscreteL1, RandHIE, Anes
from statsmodels.tools.sm_exceptions import (PerfectSeparationError,
SpecificationWarning,
ConvergenceWarning)
from scipy.stats import nbinom
try:
import cvxopt # noqa:F401
has_cvxopt = True
except ImportError:
has_cvxopt = False
# Decimal-place precision levels handed to numpy.testing.assert_almost_equal
# throughout this module; a larger number means a stricter comparison.
(DECIMAL_14, DECIMAL_10, DECIMAL_9, DECIMAL_4,
 DECIMAL_3, DECIMAL_2, DECIMAL_1, DECIMAL_0) = (14, 10, 9, 4, 3, 2, 1, 0)
class CheckModelMixin(object):
    """Assertions about the Model object itself, as opposed to its Results.

    Classes mixing this in must provide ``res1``, a fitted results instance
    whose ``model`` attribute is exercised here.
    """

    def test_fit_regularized_invalid_method(self):
        # GH#5224: an unrecognized ``method`` argument must raise ValueError.
        mod = self.res1.model
        expected_msg = r'is not supported, use either'
        with pytest.raises(ValueError, match=expected_msg):
            mod.fit_regularized(method="foo")
class CheckModelResults(CheckModelMixin):
    """
    Shared comparisons between a fitted model and reference results.

    Mixing classes must set ``res1`` (the statsmodels results instance under
    test) and ``res2`` (the reference results, either from RModelWrap or as
    defined in model_results_data).
    """

    def test_params(self):
        assert_almost_equal(self.res1.params, self.res2.params, DECIMAL_4)

    def test_conf_int(self):
        # ``res2.conf_int`` is a stored reference value, not a method.
        assert_allclose(self.res1.conf_int(), self.res2.conf_int, rtol=8e-5)

    def test_zstat(self):
        assert_almost_equal(self.res1.tvalues, self.res2.z, DECIMAL_4)

    def test_pvalues(self):
        assert_almost_equal(self.res1.pvalues, self.res2.pvalues, DECIMAL_4)

    def test_cov_params(self):
        if not hasattr(self.res2, "cov_params"):
            # pytest.skip raises on its own; a leading ``raise`` is dead code.
            pytest.skip("TODO: implement res2.cov_params")
        assert_almost_equal(self.res1.cov_params(),
                            self.res2.cov_params,
                            DECIMAL_4)

    def test_llf(self):
        assert_almost_equal(self.res1.llf, self.res2.llf, DECIMAL_4)

    def test_llnull(self):
        assert_almost_equal(self.res1.llnull, self.res2.llnull, DECIMAL_4)

    def test_llr(self):
        assert_almost_equal(self.res1.llr, self.res2.llr, DECIMAL_3)

    def test_llr_pvalue(self):
        assert_almost_equal(self.res1.llr_pvalue,
                            self.res2.llr_pvalue,
                            DECIMAL_4)

    @pytest.mark.xfail(reason="Test has not been implemented for this class.",
                       strict=True, raises=NotImplementedError)
    def test_normalized_cov_params(self):
        raise NotImplementedError

    def test_bse(self):
        assert_almost_equal(self.res1.bse, self.res2.bse, DECIMAL_4)

    def test_dof(self):
        assert_equal(self.res1.df_model, self.res2.df_model)
        assert_equal(self.res1.df_resid, self.res2.df_resid)

    def test_aic(self):
        assert_almost_equal(self.res1.aic, self.res2.aic, DECIMAL_3)

    def test_bic(self):
        assert_almost_equal(self.res1.bic, self.res2.bic, DECIMAL_3)

    def test_predict(self):
        assert_almost_equal(self.res1.model.predict(self.res1.params),
                            self.res2.phat, DECIMAL_4)

    def test_predict_xb(self):
        assert_almost_equal(self.res1.model.predict(self.res1.params,
                                                    linear=True),
                            self.res2.yhat, DECIMAL_4)

    def test_loglikeobs(self):
        # Internal cross check: per-observation log-likelihoods must sum
        # to the reported total log-likelihood.
        llobssum = self.res1.model.loglikeobs(self.res1.params).sum()
        assert_almost_equal(llobssum, self.res1.llf, DECIMAL_14)

    def test_jac(self):
        # Internal cross check: per-observation scores summed over the
        # observation axis must equal the total score.
        jacsum = self.res1.model.score_obs(self.res1.params).sum(0)
        score = self.res1.model.score(self.res1.params)
        # NOTE(review): DECIMAL_9 rather than DECIMAL_14, possibly because
        # Poisson has lower precision here -- unconfirmed.
        assert_almost_equal(jacsum, score, DECIMAL_9)
class CheckBinaryResults(CheckModelResults):
    """Additional checks for binary-response results: the prediction table
    and the deviance, generalized and response residuals."""

    def test_pred_table(self):
        actual = self.res1.pred_table()
        expected = self.res2.pred_table
        assert_array_equal(actual, expected)

    def test_resid_dev(self):
        actual, expected = self.res1.resid_dev, self.res2.resid_dev
        assert_almost_equal(actual, expected, DECIMAL_4)

    def test_resid_generalized(self):
        actual = self.res1.resid_generalized
        expected = self.res2.resid_generalized
        assert_almost_equal(actual, expected, DECIMAL_4)

    @pytest.mark.smoke
    def test_resid_response(self):
        # Smoke test: only verifies the attribute can be accessed.
        self.res1.resid_response
class CheckMargEff(object):
    """
    Test marginal effects (margeff) and its options.

    Mixing classes must provide ``res1`` (a fitted results instance whose
    ``get_margeff`` is exercised) and ``res2`` (reference values). For every
    option combination, ``res2`` holds the reference point estimates under
    ``<name>`` and their standard errors under ``<name>_se``.
    """

    def _assert_margeff_matches(self, name, **kwargs):
        """Compare ``res1.get_margeff(**kwargs)`` with reference ``res2``.

        Parameters
        ----------
        name : str
            Attribute name on ``res2`` holding the reference marginal
            effects. The reference standard errors are read from
            ``name + "_se"``.
        **kwargs
            Forwarded unchanged to ``res1.get_margeff``.

        Returns
        -------
        The object returned by ``get_margeff``, so callers can run
        additional checks on it.
        """
        me = self.res1.get_margeff(**kwargs)
        assert_almost_equal(me.margeff, getattr(self.res2, name), DECIMAL_4)
        # Always pair standard errors with the ``_se`` reference; comparing
        # them with the point estimates would not test the SEs at all.
        assert_almost_equal(me.margeff_se, getattr(self.res2, name + "_se"),
                            DECIMAL_4)
        return me

    def test_nodummy_dydxoverall(self):
        me = self._assert_margeff_matches("margeff_nodummy_dydx")
        # The summary frame's "dy/dx" column must mirror ``margeff`` exactly.
        me_frame = me.summary_frame()
        eff = me_frame["dy/dx"].values
        assert_allclose(eff, me.margeff, rtol=1e-13)
        assert_equal(me_frame.shape, (me.margeff.size, 6))

    def test_nodummy_dydxmean(self):
        self._assert_margeff_matches("margeff_nodummy_dydxmean", at='mean')

    def test_nodummy_dydxmedian(self):
        self._assert_margeff_matches("margeff_nodummy_dydxmedian",
                                     at='median')

    def test_nodummy_dydxzero(self):
        # Standard errors are checked against margeff_nodummy_dydxzero_se,
        # not against the point estimates.
        self._assert_margeff_matches("margeff_nodummy_dydxzero", at='zero')

    def test_nodummy_dyexoverall(self):
        self._assert_margeff_matches("margeff_nodummy_dyex", method='dyex')

    def test_nodummy_dyexmean(self):
        self._assert_margeff_matches("margeff_nodummy_dyexmean",
                                     at='mean', method='dyex')

    def test_nodummy_dyexmedian(self):
        self._assert_margeff_matches("margeff_nodummy_dyexmedian",
                                     at='median', method='dyex')

    def test_nodummy_dyexzero(self):
        self._assert_margeff_matches("margeff_nodummy_dyexzero",
                                     at='zero', method='dyex')

    def test_nodummy_eydxoverall(self):
        self._assert_margeff_matches("margeff_nodummy_eydx", method='eydx')

    def test_nodummy_eydxmean(self):
        self._assert_margeff_matches("margeff_nodummy_eydxmean",
                                     at='mean', method='eydx')

    def test_nodummy_eydxmedian(self):
        self._assert_margeff_matches("margeff_nodummy_eydxmedian",
                                     at='median', method='eydx')

    def test_nodummy_eydxzero(self):
        self._assert_margeff_matches("margeff_nodummy_eydxzero",
                                     at='zero', method='eydx')

    def test_nodummy_eyexoverall(self):
        self._assert_margeff_matches("margeff_nodummy_eyex", method='eyex')

    def test_nodummy_eyexmean(self):
        self._assert_margeff_matches("margeff_nodummy_eyexmean",
                                     at='mean', method='eyex')

    def test_nodummy_eyexmedian(self):
        self._assert_margeff_matches("margeff_nodummy_eyexmedian",
                                     at='median', method='eyex')

    def test_nodummy_eyexzero(self):
        self._assert_margeff_matches("margeff_nodummy_eyexzero",
                                     at='zero', method='eyex')

    def test_dummy_dydxoverall(self):
        self._assert_margeff_matches("margeff_dummy_dydx", dummy=True)

    def test_dummy_dydxmean(self):
        self._assert_margeff_matches("margeff_dummy_dydxmean",
                                     at='mean', dummy=True)

    def test_dummy_eydxoverall(self):
        self._assert_margeff_matches("margeff_dummy_eydx",
                                     method='eydx', dummy=True)

    def test_dummy_eydxmean(self):
        self._assert_margeff_matches("margeff_dummy_eydxmean",
                                     at='mean', method='eydx', dummy=True)

    def test_count_dydxoverall(self):
        self._assert_margeff_matches("margeff_count_dydx", count=True)

    def test_count_dydxmean(self):
        self._assert_margeff_matches("margeff_count_dydxmean",
                                     count=True, at='mean')

    def test_count_dummy_dydxoverall(self):
        self._assert_margeff_matches("margeff_count_dummy_dydxoverall",
                                     count=True, dummy=True)

    def test_count_dummy_dydxmean(self):
        self._assert_margeff_matches("margeff_count_dummy_dydxmean",
                                     count=True, dummy=True, at='mean')
class TestProbitNewton(CheckBinaryResults):
@classmethod
Loading ...