136 lines
4.3 KiB
Python
136 lines
4.3 KiB
Python
|
"""Test loaders for common functionality."""
|
||
|
import inspect
|
||
|
import os
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
import sklearn.datasets
|
||
|
|
||
|
|
||
|
def is_pillow_installed():
|
||
|
try:
|
||
|
import PIL # noqa
|
||
|
|
||
|
return True
|
||
|
except ImportError:
|
||
|
return False
|
||
|
|
||
|
|
||
|
FETCH_PYTEST_MARKERS = {
|
||
|
"return_X_y": {
|
||
|
"fetch_20newsgroups": pytest.mark.xfail(
|
||
|
reason="X is a list and does not have a shape argument"
|
||
|
),
|
||
|
"fetch_openml": pytest.mark.xfail(
|
||
|
reason="fetch_opeml requires a dataset name or id"
|
||
|
),
|
||
|
"fetch_lfw_people": pytest.mark.skipif(
|
||
|
not is_pillow_installed(), reason="pillow is not installed"
|
||
|
),
|
||
|
},
|
||
|
"as_frame": {
|
||
|
"fetch_openml": pytest.mark.xfail(
|
||
|
reason="fetch_opeml requires a dataset name or id"
|
||
|
),
|
||
|
},
|
||
|
}
|
||
|
|
||
|
|
||
|
def check_pandas_dependency_message(fetch_func):
|
||
|
try:
|
||
|
import pandas # noqa
|
||
|
|
||
|
pytest.skip("This test requires pandas to not be installed")
|
||
|
except ImportError:
|
||
|
# Check that pandas is imported lazily and that an informative error
|
||
|
# message is raised when pandas is missing:
|
||
|
name = fetch_func.__name__
|
||
|
expected_msg = f"{name} with as_frame=True requires pandas"
|
||
|
with pytest.raises(ImportError, match=expected_msg):
|
||
|
fetch_func(as_frame=True)
|
||
|
|
||
|
|
||
|
def check_return_X_y(bunch, dataset_func):
|
||
|
X_y_tuple = dataset_func(return_X_y=True)
|
||
|
assert isinstance(X_y_tuple, tuple)
|
||
|
assert X_y_tuple[0].shape == bunch.data.shape
|
||
|
assert X_y_tuple[1].shape == bunch.target.shape
|
||
|
|
||
|
|
||
|
def check_as_frame(
|
||
|
bunch, dataset_func, expected_data_dtype=None, expected_target_dtype=None
|
||
|
):
|
||
|
pd = pytest.importorskip("pandas")
|
||
|
frame_bunch = dataset_func(as_frame=True)
|
||
|
assert hasattr(frame_bunch, "frame")
|
||
|
assert isinstance(frame_bunch.frame, pd.DataFrame)
|
||
|
assert isinstance(frame_bunch.data, pd.DataFrame)
|
||
|
assert frame_bunch.data.shape == bunch.data.shape
|
||
|
if frame_bunch.target.ndim > 1:
|
||
|
assert isinstance(frame_bunch.target, pd.DataFrame)
|
||
|
else:
|
||
|
assert isinstance(frame_bunch.target, pd.Series)
|
||
|
assert frame_bunch.target.shape[0] == bunch.target.shape[0]
|
||
|
if expected_data_dtype is not None:
|
||
|
assert np.all(frame_bunch.data.dtypes == expected_data_dtype)
|
||
|
if expected_target_dtype is not None:
|
||
|
assert np.all(frame_bunch.target.dtypes == expected_target_dtype)
|
||
|
|
||
|
# Test for return_X_y and as_frame=True
|
||
|
frame_X, frame_y = dataset_func(as_frame=True, return_X_y=True)
|
||
|
assert isinstance(frame_X, pd.DataFrame)
|
||
|
if frame_y.ndim > 1:
|
||
|
assert isinstance(frame_X, pd.DataFrame)
|
||
|
else:
|
||
|
assert isinstance(frame_y, pd.Series)
|
||
|
|
||
|
|
||
|
def _skip_network_tests():
|
||
|
return os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "1"
|
||
|
|
||
|
|
||
|
def _generate_func_supporting_param(param, dataset_type=("load", "fetch")):
|
||
|
markers_fetch = FETCH_PYTEST_MARKERS.get(param, {})
|
||
|
for name, obj in inspect.getmembers(sklearn.datasets):
|
||
|
if not inspect.isfunction(obj):
|
||
|
continue
|
||
|
|
||
|
is_dataset_type = any([name.startswith(t) for t in dataset_type])
|
||
|
is_support_param = param in inspect.signature(obj).parameters
|
||
|
if is_dataset_type and is_support_param:
|
||
|
# check if we should skip if we don't have network support
|
||
|
marks = [
|
||
|
pytest.mark.skipif(
|
||
|
condition=name.startswith("fetch") and _skip_network_tests(),
|
||
|
reason="Skip because fetcher requires internet network",
|
||
|
)
|
||
|
]
|
||
|
if name in markers_fetch:
|
||
|
marks.append(markers_fetch[name])
|
||
|
|
||
|
yield pytest.param(name, obj, marks=marks)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"name, dataset_func", _generate_func_supporting_param("return_X_y")
|
||
|
)
|
||
|
def test_common_check_return_X_y(name, dataset_func):
|
||
|
bunch = dataset_func()
|
||
|
check_return_X_y(bunch, dataset_func)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"name, dataset_func", _generate_func_supporting_param("as_frame")
|
||
|
)
|
||
|
def test_common_check_as_frame(name, dataset_func):
|
||
|
bunch = dataset_func()
|
||
|
check_as_frame(bunch, dataset_func)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"name, dataset_func", _generate_func_supporting_param("as_frame")
|
||
|
)
|
||
|
def test_common_check_pandas_dependency(name, dataset_func):
|
||
|
check_pandas_dependency_message(dataset_func)
|