Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

alkaline-ml / statsmodels   python

Repository URL to install this package:

Version: 0.11.1 

/ sandbox / nonparametric / tests / ex_gam_am_new.py

# -*- coding: utf-8 -*-
"""Example for gam.AdditiveModel and PolynomialSmoother

This example was written as a test case.
The data generating process is chosen so the parameters are well identified
and estimated.

Created on Fri Nov 04 13:45:43 2011

Author: Josef Perktold

"""
from statsmodels.compat.python import lrange

import numpy as np

from statsmodels.sandbox.gam import AdditiveModel
from statsmodels.regression.linear_model import OLS

np.random.seed(8765993)
# The seed was chosen because it gives a nice result, not at random;
# with other seeds the prediction is noticeably off.

# DGP: simple polynomial in two regressors
order = 3
sigma_noise = 0.5
nobs = 1000  # with 1000, OLS and AdditiveModel agree in params at 2 decimals
lb, ub = -3.5, 4  # previously ub = 2.5
x1 = np.linspace(lb, ub, nobs)
x2 = np.sin(2*x1)
# first regressor is rescaled so that its maximum is 2
x = np.column_stack((x1/x1.max()*2, x2))
# Powers 0..order of each regressor; columns are laid out as
# [x[:,0]**0, ..., x[:,0]**order, x[:,1]**0, ..., x[:,1]**order].
exog = (x[:,:,None]**np.arange(order+1)[None, None, :]).reshape(nobs, -1)
# Drop the second (duplicate) constant column, x[:,1]**0, so the OLS
# design matrix has full column rank.
idx = [col for col in range((order + 1) * 2) if col != order + 1]
exog_reduced = exog[:, idx]
# true regression function: half the sum of all polynomial terms
y_true = exog.sum(1) / 2.
z = y_true  # alias check
d = x
y = y_true + sigma_noise * np.random.randn(nobs)

# Which example to run: 1 fits the AdditiveModel (and enables the plots
# further down); 0 runs only the OLS benchmark.
example = 1

if example == 1:
    m = AdditiveModel(d)
    m.fit(y)

    y_pred = m.results.predict(d)

    # Estimated coefficients of each component smoother (presumably one
    # per column of d). This must stay inside the branch: `m` only exists
    # when the additive model has been fit.
    for ss in m.smoothers:
        print(ss.params)
elif example != 0:
    # Fail early with a clear message instead of a later NameError on `m`.
    raise ValueError('unsupported example: %r' % (example,))

# Benchmark: OLS on the polynomial design with the duplicate constant removed.
res_ols = OLS(y, exog_reduced).fit()
print(res_ols.params)

# With nobs=1000 the AdditiveModel fit should agree closely with OLS, e.g.:
#from numpy.testing import assert_almost_equal
#assert_almost_equal(y_pred, res_ols.fittedvalues, 3)

if example > 0:
    import matplotlib.pyplot as plt

    # Columns of the full design matrix (powers of both regressors).
    plt.figure()
    plt.plot(exog)

    # Plot the fitted mean stored on the results object; this replaces the
    # predict(d) values computed when the model was fit.
    y_pred = m.results.mu
    plt.figure()

    # Panel 1: data, true curve, OLS and AdditiveModel fits in observation order.
    plt.subplot(2, 2, 1)
    plt.plot(y, '.', alpha=0.25)
    plt.plot(y_true, 'k-', label='true')
    # alpha must lie in [0, 1]; matplotlib rejects negative values.
    plt.plot(res_ols.fittedvalues, 'g-', label='OLS', lw=2, alpha=0.7)
    plt.plot(y_pred, 'r-', label='AM')
    plt.legend(loc='upper left')
    plt.title('gam.AdditiveModel')

    # Panels 2-4: the same quantities sorted by z (= y_true) and by each
    # regressor ('x1' is the rescaled first column of x).
    panels = zip(['z', 'x1', 'x2'], [z, x[:, 0], x[:, 1]])
    for counter, (ii, xx) in enumerate(panels, start=2):
        sortidx = np.argsort(xx)
        plt.subplot(2, 2, counter)
        plt.plot(xx[sortidx], y[sortidx], '.', alpha=0.25)
        plt.plot(xx[sortidx], y_true[sortidx], 'k.', label='true', lw=2)
        plt.plot(xx[sortidx], y_pred[sortidx], 'r.', label='AM')
        plt.legend(loc='upper left')
        plt.title('gam.AdditiveModel ' + ii)

    plt.show()