Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

alkaline-ml / statsmodels   python

Repository URL to install this package:

Version: 0.11.1 

/ gam / gam_cross_validation / cross_validators.py

# -*- coding: utf-8 -*-
"""
Cross-validation iterators for GAM

Author: Luca Puggini

"""

from abc import ABCMeta, abstractmethod
from statsmodels.compat.python import with_metaclass
import numpy as np


class BaseCrossValidator(with_metaclass(ABCMeta)):
    """
    The BaseCrossValidator class is a base class for all the iterators that
    split the data in train and test as for example KFolds or LeavePOut.

    ``split`` is declared abstract, so subclasses are expected to implement
    it before they can be instantiated.
    """
    def __init__(self):
        pass

    @abstractmethod
    def split(self, X, y=None, label=None):
        """Yield ``(train_index, test_index)`` pairs, one per split.

        Parameters
        ----------
        X : array_like
            Data whose first dimension is the number of observations.
        y : array_like, optional
            Response; implementations may ignore it.
        label : array_like, optional
            Group labels; implementations may ignore it.

        Yields
        ------
        train_index, test_index
            Index arrays (or boolean masks) selecting the training and test
            observations of one split.
        """
        pass


class KFold(BaseCrossValidator):
    """
    K-Folds cross validation iterator:
    Provides train/test indexes to split data in train test sets

    Parameters
    ----------
    k_folds : int
        number of folds
    shuffle : bool
        If true, then the index is shuffled before splitting into train and
        test indices. Shuffling uses the global ``np.random`` state, so seed
        it (``np.random.seed``) beforehand for reproducible folds.

    Notes
    -----
    Folds are formed with ``np.array_split``. When ``nobs`` is not divisible
    by ``k_folds``, the first ``nobs % k_folds`` folds contain
    ``nobs // k_folds + 1`` observations and the remaining folds contain
    ``nobs // k_folds``. If ``k_folds > nobs`` the trailing folds are empty,
    i.e. their test mask is all False and their train mask is all True.
    """

    def __init__(self, k_folds, shuffle=False):
        # Not assigned anywhere else in this class; kept as a public
        # attribute for backward compatibility.
        self.nobs = None
        self.k_folds = k_folds
        self.shuffle = shuffle

    def split(self, X, y=None, label=None):
        """Yield boolean train/test masks, one pair per fold.

        Parameters
        ----------
        X : array_like
            Data whose first dimension is the number of observations. Only
            its length along the first axis is used.
        y : ignored
            Accepted for interface compatibility; not used.
        label : ignored
            Accepted for interface compatibility; not used.

        Yields
        ------
        train_index : ndarray of bool, shape (nobs,)
            True for the observations in the training set of this fold.
        test_index : ndarray of bool, shape (nobs,)
            True for the observations in the test set of this fold; always
            the element-wise complement of ``train_index``.
        """
        # TODO: X and y are redundant, we only need nobs
        # np.shape falls back to asarray for inputs without a ``.shape``
        # attribute (e.g. nested lists), so X need not be an ndarray.
        nobs = np.shape(X)[0]
        # arange is always integer-typed; np.array(range(0)) would be float64
        # and could not be used as an index for empty input.
        index = np.arange(nobs)

        if self.shuffle:
            np.random.shuffle(index)

        for fold in np.array_split(index, self.k_folds):
            test_index = np.zeros(nobs, dtype=np.bool_)
            test_index[fold] = True
            train_index = np.logical_not(test_index)
            yield train_index, test_index