"""Test the openml loader.
"""
import gzip
import json
import os
import re
from functools import partial
from urllib.error import HTTPError

import numpy as np
import pytest
import scipy.sparse

import sklearn
from sklearn import config_context
from sklearn.datasets import fetch_openml
from sklearn.datasets._openml import (_open_openml_url,
                                      _arff,
                                      _DATA_FILE,
                                      _get_data_description_by_id,
                                      _get_local_path,
                                      _retry_with_clean_cache,
                                      _feature_to_dtype)
from sklearn.datasets.tests.test_common import check_return_X_y
from sklearn.utils import is_scalar_nan
from sklearn.utils._testing import (assert_warns_message,
                                    assert_raise_message,
                                    assert_allclose,
                                    assert_array_equal)


currdir = os.path.dirname(os.path.abspath(__file__))
# if True, urlopen will be monkey patched to only use local files
test_offline = True


def _test_features_list(data_id):
    # XXX Test is intended to verify correct decoding behavior. It is not
    # usable with sparse data or with datasets that have columns marked as
    # {row_identifier, ignore}
    def decode_column(data_bunch, col_idx):
        col_name = data_bunch.feature_names[col_idx]
        if col_name in data_bunch.categories:
            # XXX: This would be faster with np.take, but np.take does not
            # handle missing values gracefully (not even with mode='wrap')
            cat = data_bunch.categories[col_name]
            result = [None if is_scalar_nan(idx) else cat[int(idx)]
                      for idx in data_bunch.data[:, col_idx]]
            return np.array(result, dtype='O')
        else:
            # non-nominal attribute
            return data_bunch.data[:, col_idx]
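    # Illustrative sketch (hypothetical values): if a nominal column has
    # categories ['a', 'b', 'c'], the bunch stores it as float codes, so
    # decode_column maps [0.0, 2.0, nan] back to ['a', 'c', None].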

    data_bunch = fetch_openml(data_id=data_id, cache=False, target_column=None)

    # also obtain decoded arff
    data_description = _get_data_description_by_id(data_id, None)
    sparse = data_description['format'].lower() == 'sparse_arff'
    if sparse:
        raise ValueError('This test is not intended for sparse data, to keep '
                         'code relatively simple')
    url = _DATA_FILE.format(data_description['file_id'])
    with _open_openml_url(url, data_home=None) as f:
        data_arff = _arff.load((line.decode('utf-8') for line in f),
                               return_type=(_arff.COO if sparse
                                            else _arff.DENSE_GEN),
                               encode_nominal=False)

    data_downloaded = np.array(list(data_arff['data']), dtype='O')

    for i in range(len(data_bunch.feature_names)):
        # XXX: Test per column, as this makes it easier to avoid problems with
        # missing values

        np.testing.assert_array_equal(data_downloaded[:, i],
                                      decode_column(data_bunch, i))


def _fetch_dataset_from_openml(data_id, data_name, data_version,
                               target_column,
                               expected_observations, expected_features,
                               expected_missing,
                               expected_data_dtype, expected_target_dtype,
                               expect_sparse, compare_default_target):
    # fetches a dataset from OpenML in three ways (by name and version, by
    # name only, and by data id) using the fetch_openml function, and runs
    # various checks on the validity of the result. Note that the network
    # calls can be mocked by invoking _monkey_patch_webbased_functions
    # before invoking this function
    data_by_name_id = fetch_openml(name=data_name, version=data_version,
                                   cache=False)
    assert int(data_by_name_id.details['id']) == data_id

    # Please note that cache=False is crucial, as the monkey-patched files
    # are not consistent with reality. Without specifying the version there
    # is no guarantee that the data id will be the same, so the result of
    # this call is not checked against data_id.
    fetch_openml(name=data_name, cache=False)

    # fetch with dataset id
    data_by_id = fetch_openml(data_id=data_id, cache=False,
                              target_column=target_column)
    assert data_by_id.details['name'] == data_name
    assert data_by_id.data.shape == (expected_observations, expected_features)
    if isinstance(target_column, str):
        # single target, so target is vector
        assert data_by_id.target.shape == (expected_observations, )
        assert data_by_id.target_names == [target_column]
    elif isinstance(target_column, list):
        # multi target, so target is array
        assert data_by_id.target.shape == (expected_observations,
                                           len(target_column))
        assert data_by_id.target_names == target_column
    assert data_by_id.data.dtype == expected_data_dtype
    assert data_by_id.target.dtype == expected_target_dtype
    assert len(data_by_id.feature_names) == expected_features
    for feature in data_by_id.feature_names:
        assert isinstance(feature, str)

    # TODO: pass in a list of expected nominal features
    for feature, categories in data_by_id.categories.items():
        feature_idx = data_by_id.feature_names.index(feature)
        values = np.unique(data_by_id.data[:, feature_idx])
        values = values[np.isfinite(values)]
        assert set(values) <= set(range(len(categories)))

    if compare_default_target:
        # check whether fetching with the default target gives the same data
        # as fetching with the explicit target_column
        data_by_id_default = fetch_openml(data_id=data_id, cache=False)
        np.testing.assert_allclose(data_by_id.data, data_by_id_default.data)
        if data_by_id.target.dtype == np.float64:
            np.testing.assert_allclose(data_by_id.target,
                                       data_by_id_default.target)
        else:
            assert np.array_equal(data_by_id.target, data_by_id_default.target)

    if expect_sparse:
        assert isinstance(data_by_id.data, scipy.sparse.csr_matrix)
    else:
        assert isinstance(data_by_id.data, np.ndarray)
        # np.isnan doesn't work on CSR matrix
        assert (np.count_nonzero(np.isnan(data_by_id.data)) ==
                expected_missing)

    # test return_X_y option
    fetch_func = partial(fetch_openml, data_id=data_id, cache=False,
                         target_column=target_column)
    check_return_X_y(data_by_id, fetch_func)
    return data_by_id
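
# Sketch of a typical invocation (hypothetical values based on the iris
# dataset, data_id 61, which is used elsewhere in this file):
#
#     _fetch_dataset_from_openml(
#         data_id=61, data_name='iris', data_version=1, target_column='class',
#         expected_observations=150, expected_features=4, expected_missing=0,
#         expected_data_dtype=np.float64, expected_target_dtype=object,
#         expect_sparse=False, compare_default_target=True)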


def _monkey_patch_webbased_functions(context,
                                     data_id,
                                     gzip_response):
    # monkey patches the urlopen function. Important note: Do NOT use this
    # in combination with a regular cache directory, as the mocked files
    # stored in the cache should not be mixed up with real openml datasets
    url_prefix_data_description = "https://openml.org/api/v1/json/data/"
    url_prefix_data_features = "https://openml.org/api/v1/json/data/features/"
    url_prefix_download_data = "https://openml.org/data/v1/"
    url_prefix_data_list = "https://openml.org/api/v1/json/data/list/"

    path_suffix = '.gz'
    read_fn = gzip.open

    class MockHTTPResponse:
        def __init__(self, data, is_gzip):
            self.data = data
            self.is_gzip = is_gzip

        def read(self, amt=-1):
            return self.data.read(amt)

        def tell(self):
            return self.data.tell()

        def seek(self, pos, whence=0):
            return self.data.seek(pos, whence)

        def close(self):
            self.data.close()

        def info(self):
            if self.is_gzip:
                return {'Content-Encoding': 'gzip'}
            return {}

        def __iter__(self):
            return iter(self.data)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            return False

    def _file_name(url, suffix):
        return (re.sub(r'\W', '-', url[len("https://openml.org/"):])
                + suffix + path_suffix)
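    # For example (sketch): the data-description URL
    # "https://openml.org/api/v1/json/data/61" with suffix '.json' maps to
    # the local file name "api-v1-json-data-61.json.gz".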

    def _mock_urlopen_data_description(url, has_gzip_header):
        assert url.startswith(url_prefix_data_description)

        path = os.path.join(currdir, 'data', 'openml', str(data_id),
                            _file_name(url, '.json'))

        if has_gzip_header and gzip_response:
            fp = open(path, 'rb')
            return MockHTTPResponse(fp, True)
        else:
            fp = read_fn(path, 'rb')
            return MockHTTPResponse(fp, False)

    def _mock_urlopen_data_features(url, has_gzip_header):
        assert url.startswith(url_prefix_data_features)
        path = os.path.join(currdir, 'data', 'openml', str(data_id),
                            _file_name(url, '.json'))
        if has_gzip_header and gzip_response:
            fp = open(path, 'rb')
            return MockHTTPResponse(fp, True)
        else:
            fp = read_fn(path, 'rb')
            return MockHTTPResponse(fp, False)

    def _mock_urlopen_download_data(url, has_gzip_header):
        assert url.startswith(url_prefix_download_data)

        path = os.path.join(currdir, 'data', 'openml', str(data_id),
                            _file_name(url, '.arff'))

        if has_gzip_header and gzip_response:
            fp = open(path, 'rb')
            return MockHTTPResponse(fp, True)
        else:
            fp = read_fn(path, 'rb')
            return MockHTTPResponse(fp, False)

    def _mock_urlopen_data_list(url, has_gzip_header):
        assert url.startswith(url_prefix_data_list)

        json_file_path = os.path.join(currdir, 'data', 'openml',
                                      str(data_id), _file_name(url, '.json'))
        # load the file to check whether it should simulate an HTTP error
        with read_fn(json_file_path, 'rb') as f:
            json_data = json.loads(f.read().decode('utf-8'))
        if 'error' in json_data:
            raise HTTPError(url=None, code=412,
                            msg='Simulated mock error',
                            hdrs=None, fp=None)

        if has_gzip_header and gzip_response:
            fp = open(json_file_path, 'rb')
            return MockHTTPResponse(fp, True)
        else:
            fp = read_fn(json_file_path, 'rb')
            return MockHTTPResponse(fp, False)

    def _mock_urlopen(request):
        url = request.get_full_url()
        has_gzip_header = request.get_header('Accept-encoding') == "gzip"
        if url.startswith(url_prefix_data_list):
            return _mock_urlopen_data_list(url, has_gzip_header)
        elif url.startswith(url_prefix_data_features):
            return _mock_urlopen_data_features(url, has_gzip_header)
        elif url.startswith(url_prefix_download_data):
            return _mock_urlopen_download_data(url, has_gzip_header)
        elif url.startswith(url_prefix_data_description):
            return _mock_urlopen_data_description(url, has_gzip_header)
        else:
            raise ValueError('Unknown mocking URL pattern: %s' % url)

    # XXX: depends on the module-level test_offline global
    if test_offline:
        context.setattr(sklearn.datasets._openml, 'urlopen', _mock_urlopen)
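
# Usage sketch (mirrors the tests below): patch the network layer first and
# always pass cache=False so that the mocked files never end up in a real
# cache directory. The test name here is hypothetical; data_id 61 is the
# iris dataset used below:
#
#     def test_something(monkeypatch):
#         _monkey_patch_webbased_functions(monkeypatch, data_id=61,
#                                          gzip_response=True)
#         bunch = fetch_openml(data_id=61, cache=False)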


@pytest.mark.parametrize('feature, expected_dtype', [
    ({'data_type': 'string', 'number_of_missing_values': '0'}, object),
    ({'data_type': 'string', 'number_of_missing_values': '1'}, object),
    ({'data_type': 'numeric', 'number_of_missing_values': '0'}, np.float64),
    ({'data_type': 'numeric', 'number_of_missing_values': '1'}, np.float64),
    ({'data_type': 'real', 'number_of_missing_values': '0'}, np.float64),
    ({'data_type': 'real', 'number_of_missing_values': '1'}, np.float64),
    ({'data_type': 'integer', 'number_of_missing_values': '0'}, np.int64),
    ({'data_type': 'integer', 'number_of_missing_values': '1'}, np.float64),
    ({'data_type': 'nominal', 'number_of_missing_values': '0'}, 'category'),
    ({'data_type': 'nominal', 'number_of_missing_values': '1'}, 'category'),
])
def test_feature_to_dtype(feature, expected_dtype):
    assert _feature_to_dtype(feature) == expected_dtype
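
# Note on the integer rows above: np.int64 cannot represent missing values,
# so an 'integer' feature with at least one missing value is widened to the
# NaN-capable np.float64. A minimal sketch of the same call:
#
#     _feature_to_dtype({'data_type': 'integer',
#                        'number_of_missing_values': '1'})  # -> np.float64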


@pytest.mark.parametrize('feature', [
    {'data_type': 'datetime', 'number_of_missing_values': '0'}
])
def test_feature_to_dtype_error(feature):
    msg = 'Unsupported feature: {}'.format(feature)
    with pytest.raises(ValueError, match=msg):
        _feature_to_dtype(feature)


def test_fetch_openml_iris_pandas(monkeypatch):
    # classification dataset with numeric only columns
    pd = pytest.importorskip('pandas')
    CategoricalDtype = pd.api.types.CategoricalDtype
    data_id = 61
    data_shape = (150, 4)
    target_shape = (150, )
    frame_shape = (150, 5)

    target_dtype = CategoricalDtype(['Iris-setosa', 'Iris-versicolor',
                                     'Iris-virginica'])
    data_dtypes = [np.float64] * 4
    data_names = ['sepallength', 'sepalwidth', 'petallength', 'petalwidth']
    target_name = 'class'

    _monkey_patch_webbased_functions(monkeypatch, data_id, True)

    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
    data = bunch.data
    target = bunch.target
    frame = bunch.frame

    assert isinstance(data, pd.DataFrame)
    assert np.all(data.dtypes == data_dtypes)
    assert data.shape == data_shape
    assert np.all(data.columns == data_names)
    assert np.all(bunch.feature_names == data_names)
    assert bunch.target_names == [target_name]

    assert isinstance(target, pd.Series)
    assert target.dtype == target_dtype
    assert target.shape == target_shape
    assert target.name == target_name
    assert target.index.is_unique

    assert isinstance(frame, pd.DataFrame)
    assert frame.shape == frame_shape
    assert np.all(frame.dtypes == data_dtypes + [target_dtype])
    assert frame.index.is_unique


def test_fetch_openml_iris_pandas_equal_to_no_frame(monkeypatch):
    # as_frame=True returns the same underlying data as as_frame=False
    pytest.importorskip('pandas')
    data_id = 61

    _monkey_patch_webbased_functions(monkeypatch, data_id, True)