Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
# -*- coding: utf-8 -*-
"""


Created on Sun May 09 22:35:21 2010
Author: josef-pktd
License: BSD

TODO:
change the moment calculation; it currently uses the default
_ppf-based method (I think), which can fail to converge, e.g.
>>> lognormalg.moment(4)
Warning: The algorithm does not converge.  Roundoff error is detected
  in the extrapolation table.  It is assumed that the requested tolerance
  cannot be achieved, and that the returned result (if full_output = 1) is
  the best which can be obtained.
array(2981.0032380193438)
"""


import numpy as np
from numpy.testing import assert_almost_equal
from scipy import stats, special
from scikits.statsmodels.sandbox.distributions.extras import (
    lognormalg, squarenormalg, absnormalg, negsquarenormalg, squaretg)


#some patches to scipy.stats.distributions so tests work and pass

#patch frozen distributions with a name: a frozen instance reports the
#name of the distribution it wraps
def _frozen_dist_name(frozen):
    """Return the ``name`` of the distribution wrapped by ``frozen``."""
    return frozen.dist.name

stats.distributions.rv_frozen.name = property(_frozen_dist_name)

#patch f distribution, correct skew and maybe kurtosis
def f_stats(self, dfn, dfd):
    """Mean, variance, skewness and excess kurtosis of the F distribution.

    Drop-in replacement for ``f_gen._stats``, installed below.  An earlier
    version of the skewness used ``(2*v2-4)`` under the square root; the
    formula here uses ``2*(v2-4)``.

    Parameters
    ----------
    self : object
        Unused; present so the function can be assigned as a method.
    dfn, dfd : scalar or array_like
        Numerator and denominator degrees of freedom.

    Returns
    -------
    mu : ndarray
        Mean ``dfd/(dfd-2)``; ``inf`` where ``dfd <= 2``.
    mu2 : ndarray
        Variance; ``inf`` where ``dfd <= 4``.
    g1 : ndarray
        Skewness; ``nan`` where ``dfd <= 6``.
    g2 : ndarray
        Excess (Fisher) kurtosis; ``nan`` where ``dfd <= 8``.
    """
    v1 = np.array(dfn * 1.0)
    v2 = np.array(dfd * 1.0)
    # np.where evaluates both branches, so the discarded branch may divide by
    # zero or take the sqrt of a negative number.  Those values are masked
    # out below, so the RuntimeWarnings they would raise are spurious.
    with np.errstate(divide='ignore', invalid='ignore'):
        mu = np.where(v2 > 2, v2 / (v2 - 2), np.inf)

        mu2 = 2 * v2 * v2 * (v2 + v1 - 2) / (v1 * (v2 - 2)**2 * (v2 - 4))
        mu2 = np.where(v2 > 4, mu2, np.inf)

        g1 = (2 * (v2 + 2 * v1 - 2.) / (v2 - 6.)
              * np.sqrt(2 * (v2 - 4.) / (v1 * (v2 + v1 - 2.))))
        g1 = np.where(v2 > 6, g1, np.nan)

        g2 = 3 / (2. * v2 - 16) * (8 + g1 * g1 * (v2 - 6.))
        g2 = np.where(v2 > 8, g2, np.nan)
    return mu, mu2, g1, g2

stats.distributions.f_gen._stats = f_stats

#correct kurtosis by subtracting 3 (Fisher)
#after this it matches halfnorm for arg close to zero
def foldnorm_stats(self, c):
    """Mean, variance, skewness and excess kurtosis of the folded normal.

    Drop-in replacement for ``foldnorm_gen._stats``, installed below.  The
    kurtosis is returned in Fisher form (3 subtracted).

    The central moments are expressed through the mean ``mu`` and the
    standard normal density ``phi`` at ``c`` (Elandt, Technometrics 3, 551,
    1961).  This is algebraically equal to the fully expanded expressions,
    but it never forms ``exp(-1.5*c**2) * exp(c**2)``.  That product evaluates
    to ``0 * inf = nan`` once ``c`` exceeds about 26.6.

    Parameters
    ----------
    self : object
        Unused; present so the function can be assigned as a method.
    c : scalar or ndarray
        Shape parameter; the formulas are the moments of ``|X|`` with
        ``X ~ N(c, 1)``.

    Returns
    -------
    mu, mu2, g1, g2
        Mean, variance, skewness and excess kurtosis.
    """
    c2 = c * c
    fac = special.erf(c / np.sqrt(2))
    # standard normal density evaluated at c
    phi = np.exp(-0.5 * c2) / np.sqrt(2 * np.pi)

    mu = np.sqrt(2.0 / np.pi) * np.exp(-0.5 * c2) + c * fac
    mu2 = c2 + 1 - mu * mu

    # third central moment is 2*(mu**3 - c**2*mu - phi)
    g1 = 2.0 * (mu * mu * mu - c2 * mu - phi) / mu2**1.5

    # fourth central moment, then standardize and subtract 3 (Fisher)
    mu4 = (c2 * (c2 + 6.0) + 3.0 + 8.0 * phi * mu
           + (2.0 * (c2 - 3.0) - 3.0 * mu * mu) * mu * mu)
    g2 = mu4 / mu2**2.0 - 3.0
    return mu, mu2, g1, g2

stats.distributions.foldnorm_gen._stats = foldnorm_stats


#-----------------------------

# Number of decimal places passed to assert_almost_equal when comparing the
# third moments in Test_Transf2.test_equivalent.
DECIMAL = 5

class Test_Transf2(object):
    """Check transformed distributions against their scipy equivalents.

    The fixtures are class attributes rather than being assigned in
    ``__init__``: pytest does not collect test classes that define
    ``__init__``, while ``Test_Transf2()`` keeps working for the script
    entry point.
    """

    # pairs of (transformed distribution, equivalent scipy.stats distribution)
    dist_equivalents = [
        #transf
        (lognormalg, stats.lognorm(1)),
        #transf2
        (squarenormalg, stats.chi2(1)),
        (absnormalg, stats.halfnorm),
        (absnormalg, stats.foldnorm(1e-5)),  #try frozen
        #(negsquarenormalg, 1-stats.chi2) won't work as a distribution;
        #it is checked separately in test_equivalent_negsq
        (squaretg(10), stats.f(1, 10))]      #try both frozen

    ppfq = [0.1, 0.5, 0.9]        # probabilities for ppf / isf
    xx = [0.95, 1.0, 1.1]         # evaluation points for cdf / pdf / sf
    nxx = [-0.95, -1.0, -1.1]     # negated points, for negsquarenormalg

    def test_equivalent(self):
        """Compare cdf, pdf, sf, ppf, isf, third moment and mvsk stats."""
        xx, ppfq = self.xx, self.ppfq
        for d1, d2 in self.dist_equivalents:
            # keep the pair under test on the instance for post-mortem
            # inspection when an assertion fails
            self.d1 = d1
            self.d2 = d2
            pair = ' %s %s' % (d1.name, d2.name)
            assert_almost_equal(d1.cdf(xx), d2.cdf(xx), err_msg='cdf' + pair)
            assert_almost_equal(d1.pdf(xx), d2.pdf(xx), err_msg='pdf' + pair)
            assert_almost_equal(d1.sf(xx), d2.sf(xx), err_msg='sf' + pair)
            assert_almost_equal(d1.ppf(ppfq), d2.ppf(ppfq),
                                err_msg='ppf' + pair)
            assert_almost_equal(d1.isf(ppfq), d2.isf(ppfq),
                                err_msg='isf' + pair)
            # work around bug#1293: for frozen distributions (those with a
            # ``dist`` attribute) take the moment from the underlying
            # distribution with the frozen shape arguments
            if hasattr(d2, 'dist'):
                d2mom = d2.dist.moment(3, *d2.args)
            else:
                d2mom = d2.moment(3)
            assert_almost_equal(d1.moment(3), d2mom, DECIMAL,
                                err_msg='moment' + pair)
            s1 = d1.stats(moments='mvsk')
            s2 = d2.stats(moments='mvsk')
            # stats(moments='k') prints warning for lognormalg
            # mean and variance at default precision; skew and kurtosis only
            # to 2 decimals (lognorm kurtosis needs the looser tolerance)
            assert_almost_equal(s1[:2], s2[:2], err_msg='stats' + pair)
            assert_almost_equal(s1[2:], s2[2:], decimal=2,
                                err_msg='stats' + pair)

    def test_equivalent_negsq(self):
        '''special case negsquarenormalg, the distribution of -chi2(1)

        For x <= 0 and probabilities q:
        negsquarenormalg.cdf(x) == 1 - stats.chi2(1).cdf(-x)
        negsquarenormalg.ppf(q) == -stats.chi2(1).isf(q)
        '''
        xx, nxx, ppfq = self.xx, self.nxx, self.ppfq
        d1, d2 = negsquarenormalg, stats.chi2(1)
        pair = ' %s %s' % (d1.name, d2.name)
        assert_almost_equal(d1.cdf(nxx), 1 - d2.cdf(xx), err_msg='cdf' + pair)
        assert_almost_equal(d1.pdf(nxx), d2.pdf(xx), err_msg='pdf' + pair)
        assert_almost_equal(d1.sf(nxx), 1 - d2.sf(xx), err_msg='sf' + pair)
        # the q-quantile of -Y is minus the inverse survival function of Y
        # at q; unlike reversing ppf(q), this does not require ppfq to be
        # symmetric around 0.5
        assert_almost_equal(d1.ppf(ppfq), -d2.isf(ppfq), err_msg='ppf' + pair)
        assert_almost_equal(d1.isf(ppfq), -d2.ppf(ppfq), err_msg='isf' + pair)
        assert_almost_equal(d1.moment(3), -d2.moment(3),
                            err_msg='moment' + pair)
        # negating the variable flips the sign of mean and skew only
        ch2oddneg = [v * (-1)**(i + 1) for i, v in
                     enumerate(d2.stats(moments='mvsk'))]
        assert_almost_equal(d1.stats(moments='mvsk'), ch2oddneg,
                            err_msg='stats' + pair)


if __name__ == '__main__':
    # run the tests directly, without a test runner
    tt = Test_Transf2()
    tt.test_equivalent()
    tt.test_equivalent_negsq()

    # set to a true value to print some values for manual inspection;
    # print(...) with a single argument behaves the same on Python 2 and 3
    debug = 0
    if debug:
        print(negsquarenormalg.ppf([0.1, 0.5, 0.9]))
        print(stats.chi2.ppf([0.1, 0.5, 0.9], 1))
        print(negsquarenormalg.a)
        print(negsquarenormalg.b)

        print(absnormalg.stats(moments='mvsk'))
        print(stats.foldnorm(1e-10).stats(moments='mvsk'))
        print(stats.halfnorm.stats(moments='mvsk'))