924 lines
27 KiB
Python
924 lines
27 KiB
Python
|
import atexit
|
||
|
import os
|
||
|
import unittest
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
from scipy import sparse
|
||
|
|
||
|
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
from sklearn.utils import _IS_WASM
|
||
|
from sklearn.utils._testing import (
|
||
|
TempMemmap,
|
||
|
_convert_container,
|
||
|
_delete_folder,
|
||
|
_get_warnings_filters_info_list,
|
||
|
assert_allclose,
|
||
|
assert_allclose_dense_sparse,
|
||
|
assert_no_warnings,
|
||
|
assert_raise_message,
|
||
|
assert_raises,
|
||
|
assert_raises_regex,
|
||
|
assert_run_python_script_without_output,
|
||
|
check_docstring_parameters,
|
||
|
create_memmap_backed_data,
|
||
|
ignore_warnings,
|
||
|
raises,
|
||
|
set_random_state,
|
||
|
turn_warnings_into_errors,
|
||
|
)
|
||
|
from sklearn.utils.deprecation import deprecated
|
||
|
from sklearn.utils.fixes import (
|
||
|
CSC_CONTAINERS,
|
||
|
CSR_CONTAINERS,
|
||
|
parse_version,
|
||
|
sp_version,
|
||
|
)
|
||
|
from sklearn.utils.metaestimators import available_if
|
||
|
|
||
|
|
||
|
def test_set_random_state():
    """Check `set_random_state` with and without a `random_state` parameter."""
    tree = DecisionTreeClassifier()
    # LinearDiscriminantAnalysis has no random state, so this call is only a
    # smoke test: it must simply not fail.
    set_random_state(LinearDiscriminantAnalysis(), 3)
    set_random_state(tree, 3)
    assert tree.random_state == 3
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_assert_allclose_dense_sparse(csr_container):
    """Check `assert_allclose_dense_sparse` on dense and CSR inputs.

    Parameters
    ----------
    csr_container : callable
        CSR container constructor taken from `CSR_CONTAINERS`; it matches the
        argument name and the `.tocsr()` matrix it is compared against below.
    """
    x = np.arange(9).reshape(3, 3)
    msg = "Not equal to tolerance "
    y = csr_container(x)
    for X in [x, y]:
        # basic compare: different values must fail, identical values must pass
        with pytest.raises(AssertionError, match=msg):
            assert_allclose_dense_sparse(X, X * 2)
        assert_allclose_dense_sparse(X, X)

    # mixing a dense and a sparse input is rejected outright
    with pytest.raises(ValueError, match="Can only compare two sparse"):
        assert_allclose_dense_sparse(x, y)

    # two sparse inputs with different shapes must not compare equal
    A = sparse.diags(np.ones(5), offsets=0).tocsr()
    B = csr_container(np.ones((1, 5)))
    with pytest.raises(AssertionError, match="Arrays are not equal"):
        assert_allclose_dense_sparse(B, A)
|
||
|
|
||
|
|
||
|
def test_assert_raises_msg():
    """Check that the `msg` given to `assert_raises` ends up in the failure."""
    expect_failure = assert_raises_regex(AssertionError, "Hello world")
    with expect_failure:
        # Nothing is raised here, so `assert_raises` itself must fail and
        # report the custom message.
        with assert_raises(ValueError, msg="Hello world"):
            pass
|
||
|
|
||
|
|
||
|
def test_assert_raise_message():
    """Exercise `assert_raise_message` on matching and non-matching cases."""

    def _raise_ValueError(message):
        raise ValueError(message)

    def _no_raise():
        pass

    # Right exception type and right message: passes silently.
    assert_raise_message(ValueError, "test", _raise_ValueError, "test")

    # Right exception type but wrong message: reported as AssertionError.
    with assert_raises(AssertionError):
        assert_raise_message(ValueError, "something else", _raise_ValueError, "test")

    # Unexpected exception type: the original ValueError propagates.
    with assert_raises(ValueError):
        assert_raise_message(TypeError, "something else", _raise_ValueError, "test")

    # Nothing raised at all: reported as AssertionError.
    with assert_raises(AssertionError):
        assert_raise_message(ValueError, "test", _no_raise)

    # multiple exceptions in a tuple
    with assert_raises(AssertionError):
        assert_raise_message((ValueError, AttributeError), "test", _no_raise)
|
||
|
|
||
|
|
||
|
def test_ignore_warning():
    # This checks that the ignore_warnings decorator and context manager are
    # working as expected.
    def _warning_function():
        # Emits a single DeprecationWarning.
        warnings.warn("deprecation warning", DeprecationWarning)

    def _multiple_warning_function():
        # Emits a DeprecationWarning followed by a warning with no explicit
        # category (checked as a UserWarning below).
        warnings.warn("deprecation warning", DeprecationWarning)
        warnings.warn("deprecation warning")

    # Check the function directly
    assert_no_warnings(ignore_warnings(_warning_function))
    assert_no_warnings(ignore_warnings(_warning_function, category=DeprecationWarning))
    # Ignoring an unrelated category must let the DeprecationWarning through.
    with pytest.warns(DeprecationWarning):
        ignore_warnings(_warning_function, category=UserWarning)()

    # Ignoring a category that is never emitted leaves both warnings visible.
    with pytest.warns() as record:
        ignore_warnings(_multiple_warning_function, category=FutureWarning)()
    assert len(record) == 2
    assert isinstance(record[0].message, DeprecationWarning)
    assert isinstance(record[1].message, UserWarning)

    # Ignoring UserWarning leaves only the DeprecationWarning.
    with pytest.warns() as record:
        ignore_warnings(_multiple_warning_function, category=UserWarning)()
    assert len(record) == 1
    assert isinstance(record[0].message, DeprecationWarning)

    assert_no_warnings(
        ignore_warnings(_warning_function, category=(DeprecationWarning, UserWarning))
    )

    # Check the decorator
    @ignore_warnings
    def decorator_no_warning():
        _warning_function()
        _multiple_warning_function()

    @ignore_warnings(category=(DeprecationWarning, UserWarning))
    def decorator_no_warning_multiple():
        _multiple_warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_warning():
        _warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_warning():
        _warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_multiple_warning():
        _multiple_warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_multiple_warning():
        _multiple_warning_function()

    assert_no_warnings(decorator_no_warning)
    assert_no_warnings(decorator_no_warning_multiple)
    assert_no_warnings(decorator_no_deprecation_warning)
    with pytest.warns(DeprecationWarning):
        decorator_no_user_warning()
    with pytest.warns(UserWarning):
        decorator_no_deprecation_multiple_warning()
    with pytest.warns(DeprecationWarning):
        decorator_no_user_multiple_warning()

    # Check the context manager
    def context_manager_no_warning():
        with ignore_warnings():
            _warning_function()

    def context_manager_no_warning_multiple():
        with ignore_warnings(category=(DeprecationWarning, UserWarning)):
            _multiple_warning_function()

    def context_manager_no_deprecation_warning():
        with ignore_warnings(category=DeprecationWarning):
            _warning_function()

    def context_manager_no_user_warning():
        with ignore_warnings(category=UserWarning):
            _warning_function()

    def context_manager_no_deprecation_multiple_warning():
        with ignore_warnings(category=DeprecationWarning):
            _multiple_warning_function()

    def context_manager_no_user_multiple_warning():
        with ignore_warnings(category=UserWarning):
            _multiple_warning_function()

    assert_no_warnings(context_manager_no_warning)
    assert_no_warnings(context_manager_no_warning_multiple)
    assert_no_warnings(context_manager_no_deprecation_warning)
    with pytest.warns(DeprecationWarning):
        context_manager_no_user_warning()
    with pytest.warns(UserWarning):
        context_manager_no_deprecation_multiple_warning()
    with pytest.warns(DeprecationWarning):
        context_manager_no_user_multiple_warning()

    # Check that passing a warning class as the first positional argument
    # (instead of through `category=`) raises an informative error.
    warning_class = UserWarning
    match = "'obj' should be a callable.+you should use 'category=UserWarning'"

    with pytest.raises(ValueError, match=match):
        silence_warnings_func = ignore_warnings(warning_class)(_warning_function)
        silence_warnings_func()

    with pytest.raises(ValueError, match=match):

        @ignore_warnings(warning_class)
        def test():
            pass
|
||
|
|
||
|
|
||
|
class TestWarns(unittest.TestCase):
    """unittest-style checks for `assert_no_warnings`."""

    def test_warn(self):
        """A warning must fail the check; a quiet call returns its result."""

        def warn_and_return():
            warnings.warn("yo")
            return 3

        def identity(value):
            return value

        with pytest.raises(AssertionError):
            assert_no_warnings(warn_and_return)

        result = assert_no_warnings(identity, 1)
        assert result == 1
|
||
|
|
||
|
|
||
|
# Tests for docstrings:
|
||
|
|
||
|
|
||
|
# Docstring fixture for test_check_docstring_parameters: the documented
# parameters match the signature, so no problem is expected to be reported.
# The docstring is test data; keep its layout exactly as is.
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
||
|
|
||
|
|
||
|
# Docstring fixture for test_check_docstring_parameters: the "Results" heading
# is deliberately not a valid section name; the test expects a RuntimeError
# matching "Unknown section Results". The docstring is test data.
def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
||
|
|
||
|
|
||
|
# Docstring fixture for test_check_docstring_parameters: the signature order is
# (b, a) while the docstring documents a then b, so a name mismatch at index 0
# is expected. The docstring is test data.
def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
||
|
|
||
|
|
||
|
# Docstring fixture for test_check_docstring_parameters: the docstring
# documents a parameter `c` that is not in the signature, so "first extra
# item: c" is expected. The docstring is test data.
def f_too_many_param_docstring(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : int
        Parameter b
    c : int
        Parameter c

    Returns
    -------
    d : list
        Parameter c
    """
    d = a + b
    return d
|
||
|
|
||
|
|
||
|
# Docstring fixture for test_check_docstring_parameters: parameter `b` is left
# undocumented on purpose, so "first missing item: b" is expected unless `b`
# is passed in `ignore`. The docstring is test data.
def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
||
|
|
||
|
|
||
|
# Docstring fixture for test_check_docstring_parameters: entries `a: int`,
# `b:` and `d:int` have no space before the colon and are the three entries
# the test expects to be reported; `c :` and the bare `e` are not expected to
# be reported. The docstring is test data.
def f_check_param_definition(a, b, c, d, e):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        This is parsed correctly in numpydoc 1.2
    d:int
        Parameter d
    e
        No typespec is allowed without colon
    """
    # `e` is only there for the docstring check and is not used in the result.
    return a + b + c + d
|
||
|
|
||
|
|
||
|
# Method fixtures for test_check_docstring_parameters. The docstrings (or their
# absence) are test data and must not be "fixed".
class Klass:
    # No docstring at all: the test expects "first missing item: X".
    def f_missing(self, X, y):
        pass

    # "Parameter" (singular) is deliberately not a valid section heading; the
    # test expects a RuntimeError matching "Unknown section Parameter".
    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ---------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        pass
|
||
|
|
||
|
|
||
|
class MockEst:
    """Minimal estimator stand-in used as the delegate of a meta-estimator."""

    def __init__(self):
        """MockEstimator"""

    def fit(self, X, y):
        # Echo the input back; no fitting is done.
        return X

    def predict(self, X):
        # Echo the input back.
        return X

    def predict_proba(self, X):
        # Echo the input back.
        return X

    def score(self, X):
        # Constant score, independent of the input.
        return 1.0
|
||
|
|
||
|
|
||
|
# Meta-estimator fixture for test_check_docstring_parameters. Its methods are
# wrapped with `available_if` and/or `deprecated`, and their docstrings contain
# deliberate defects that the test expects to be reported. The docstrings are
# test data and must not be "fixed".
class MockMetaEstimator:
    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    # Documents `y` while the signature takes `X`: a name mismatch is expected.
    @available_if(lambda self: hasattr(self.delegate, "predict"))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # The "Parameters" underline has 9 dashes instead of 10: the test expects
    # the "potentially wrong underline length" report. The method has no body
    # besides its docstring.
    @available_if(lambda self: hasattr(self.delegate, "score"))
    @deprecated("Testing a deprecated delegated method")
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Same short underline as `score`; the parameter name itself is correct.
    @available_if(lambda self: hasattr(self.delegate, "predict_proba"))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # No parameters documented: the test expects "first missing item: X".
    @deprecated("Testing deprecated function with wrong params")
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
|
||
|
|
||
|
|
||
|
def test_check_docstring_parameters():
    """Check the problems reported by `check_docstring_parameters`.

    The fixtures defined in this module (`f_ok`, `f_bad_order`, `Klass`,
    `MockMetaEstimator`, ...) carry deliberately well-formed or malformed
    docstrings; this test pins the exact report produced for each of them.
    The test is skipped when numpydoc (>= 1.2.0) is not installed.
    """
    pytest.importorskip(
        "numpydoc",
        reason="numpydoc is required to test the docstrings",
        minversion="1.2.0",
    )

    # Well-formed docstrings, or docstrings whose only defect is ignored.
    incorrect = check_docstring_parameters(f_ok)
    assert incorrect == []
    incorrect = check_docstring_parameters(f_ok, ignore=["b"])
    assert incorrect == []
    incorrect = check_docstring_parameters(f_missing, ignore=["b"])
    assert incorrect == []
    # Unknown section headings are reported as errors.
    with pytest.raises(RuntimeError, match="Unknown section Results"):
        check_docstring_parameters(f_bad_sections)
    with pytest.raises(RuntimeError, match="Unknown section Parameter"):
        check_docstring_parameters(Klass.f_bad_sections)

    incorrect = check_docstring_parameters(f_check_param_definition)
    mock_meta = MockMetaEstimator(delegate=MockEst())
    mock_meta_name = mock_meta.__class__.__name__
    assert incorrect == [
        (
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('a: int')"
        ),
        (
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('b:')"
        ),
        (
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('d:int')"
        ),
    ]

    # Expected reports, in the same order as the callables zipped below.
    # The "?" entries are difflib guide lines: their padding aligns the
    # markers with the differing characters of the line above, so the number
    # of spaces is significant and must not be collapsed.
    messages = [
        [
            "In function: sklearn.utils.tests.test_testing.f_bad_order",
            (
                "There's a parameter name mismatch in function docstring w.r.t."
                " function signature, at index 0 diff: 'b' != 'a'"
            ),
            "Full diff:",
            "- ['b', 'a']",
            "+ ['a', 'b']",
        ],
        [
            "In function: "
            + "sklearn.utils.tests.test_testing.f_too_many_param_docstring",
            (
                "Parameters in function docstring have more items w.r.t. function"
                " signature, first extra item: c"
            ),
            "Full diff:",
            "- ['a', 'b']",
            "+ ['a', 'b', 'c']",
            "?          +++++",
        ],
        [
            "In function: sklearn.utils.tests.test_testing.f_missing",
            (
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: b"
            ),
            "Full diff:",
            "- ['a', 'b']",
            "+ ['a']",
        ],
        [
            "In function: sklearn.utils.tests.test_testing.Klass.f_missing",
            (
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: X"
            ),
            "Full diff:",
            "- ['X', 'y']",
            "+ []",
        ],
        [
            "In function: "
            + f"sklearn.utils.tests.test_testing.{mock_meta_name}.predict",
            (
                "There's a parameter name mismatch in function docstring w.r.t."
                " function signature, at index 0 diff: 'X' != 'y'"
            ),
            "Full diff:",
            "- ['X']",
            "?   ^",
            "+ ['y']",
            "?   ^",
        ],
        [
            "In function: "
            + f"sklearn.utils.tests.test_testing.{mock_meta_name}."
            + "predict_proba",
            "potentially wrong underline length... ",
            "Parameters ",
            "--------- in ",
        ],
        [
            "In function: "
            + f"sklearn.utils.tests.test_testing.{mock_meta_name}.score",
            "potentially wrong underline length... ",
            "Parameters ",
            "--------- in ",
        ],
        [
            "In function: " + f"sklearn.utils.tests.test_testing.{mock_meta_name}.fit",
            (
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: X"
            ),
            "Full diff:",
            "- ['X', 'y']",
            "+ []",
        ],
    ]

    for msg, f in zip(
        messages,
        [
            f_bad_order,
            f_too_many_param_docstring,
            f_missing,
            Klass.f_missing,
            mock_meta.predict,
            mock_meta.predict_proba,
            mock_meta.score,
            mock_meta.fit,
        ],
    ):
        incorrect = check_docstring_parameters(f)
        assert msg == incorrect, '\n"%s"\n not in \n"%s"' % (msg, incorrect)
|
||
|
|
||
|
|
||
|
class RegistrationCounter:
    """Callable stand-in for `atexit.register` that counts registrations."""

    def __init__(self):
        # Number of times the instance has been called so far.
        self.nb_calls = 0

    def __call__(self, to_register_func):
        """Record one call and check what is being registered."""
        self.nb_calls = self.nb_calls + 1
        # The registered callable must wrap `_delete_folder` (exposed on
        # its `.func` attribute).
        wrapped = to_register_func.func
        assert wrapped is _delete_folder
|
||
|
|
||
|
|
||
|
def check_memmap(input_array, mmap_data, mmap_mode="r"):
    """Assert that `mmap_data` is a memmap mirroring `input_array`.

    Parameters
    ----------
    input_array : array-like
        Reference values.
    mmap_data : object
        Object expected to be an `np.memmap` with the same content.
    mmap_mode : str, default="r"
        Mode the memmap was opened with; every mode other than "r" is
        expected to give a writeable memmap.
    """
    assert isinstance(mmap_data, np.memmap)
    expect_writeable = not (mmap_mode == "r")
    assert mmap_data.flags.writeable is expect_writeable
    np.testing.assert_array_equal(input_array, mmap_data)
|
||
|
|
||
|
|
||
|
def test_tempmemmap(monkeypatch):
    """Check `TempMemmap` in the default and the "r+" modes."""
    registration_counter = RegistrationCounter()
    monkeypatch.setattr(atexit, "register", registration_counter)

    input_array = np.ones(3)
    # First pass uses the defaults, second pass asks for a writeable memmap.
    mode_variants = ({}, {"mmap_mode": "r+"})
    for expected_calls, mode_kwargs in enumerate(mode_variants, start=1):
        with TempMemmap(input_array, **mode_kwargs) as data:
            check_memmap(input_array, data, **mode_kwargs)
            temp_folder = os.path.dirname(data.filename)
        # The folder removal is not checked on Windows.
        if os.name != "nt":
            assert not os.path.exists(temp_folder)
        assert registration_counter.nb_calls == expected_calls
|
||
|
|
||
|
|
||
|
@pytest.mark.xfail(_IS_WASM, reason="memmap not fully supported")
def test_create_memmap_backed_data(monkeypatch):
    """Check `create_memmap_backed_data` on arrays, lists and mixed lists."""
    registration_counter = RegistrationCounter()
    monkeypatch.setattr(atexit, "register", registration_counter)

    base_array = np.ones(3)

    # Single array, default options.
    mmapped = create_memmap_backed_data(base_array)
    check_memmap(base_array, mmapped)
    assert registration_counter.nb_calls == 1

    # Also return the folder that holds the backing file.
    mmapped, folder = create_memmap_backed_data(base_array, return_folder=True)
    check_memmap(base_array, mmapped)
    assert folder == os.path.dirname(mmapped.filename)
    assert registration_counter.nb_calls == 2

    # Writeable memmap.
    mmapped = create_memmap_backed_data(base_array, mmap_mode="r+")
    check_memmap(base_array, mmapped, "r+")
    assert registration_counter.nb_calls == 3

    # A list of arrays yields one memmap per array with a single registration.
    input_list = [base_array, base_array + 1, base_array + 2]
    mmap_data_list = create_memmap_backed_data(input_list)
    for expected, mmapped in zip(input_list, mmap_data_list):
        check_memmap(expected, mmapped)
    assert registration_counter.nb_calls == 4

    # Non-array items of a list are passed through untouched.
    last_array = input_list[-1]
    output_data, other = create_memmap_backed_data([last_array, "not-an-array"])
    check_memmap(last_array, output_data)
    assert other == "not-an-array"
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constructor_name, container_type",
    [
        ("list", list),
        ("tuple", tuple),
        ("array", np.ndarray),
        ("sparse", sparse.csr_matrix),
        # using `zip` will only keep the available sparse containers
        # depending of the installed SciPy version
        *zip(["sparse_csr", "sparse_csr_array"], CSR_CONTAINERS),
        *zip(["sparse_csc", "sparse_csc_array"], CSC_CONTAINERS),
        ("dataframe", lambda: pytest.importorskip("pandas").DataFrame),
        ("series", lambda: pytest.importorskip("pandas").Series),
        ("index", lambda: pytest.importorskip("pandas").Index),
        ("slice", slice),
    ],
)
@pytest.mark.parametrize(
    "dtype, superdtype",
    [
        (np.int32, np.integer),
        (np.int64, np.integer),
        (np.float32, np.floating),
        (np.float64, np.floating),
    ],
)
def test_convert_container(constructor_name, container_type, dtype, superdtype):
    """Check that `_convert_container` returns the requested container type
    holding the requested data type."""
    pandas_constructors = ("dataframe", "series", "index")
    if constructor_name in pandas_constructors:
        # The pandas types are wrapped in lambdas so that a missing pandas
        # only skips these cases rather than the whole file.
        container_type = container_type()

    container_converted = _convert_container([0, 1], constructor_name, dtype=dtype)
    assert isinstance(container_converted, container_type)

    python_scalar_constructors = ("list", "tuple", "index")
    if constructor_name in python_scalar_constructors:
        # list and tuple hold plain Python int/float; a pandas index always
        # uses high precision (np.int64 / np.float64), so only the kind of
        # the first element is checked.
        first_item = container_converted[0]
        assert np.issubdtype(type(first_item), superdtype)
    elif hasattr(container_converted, "dtype"):
        assert container_converted.dtype == dtype
    elif hasattr(container_converted, "dtypes"):
        assert container_converted.dtypes[0] == dtype
|
||
|
|
||
|
|
||
|
def test_convert_container_categories_pandas():
    """Check that categorical feature names give a pandas category column."""
    pytest.importorskip("pandas")
    frame = _convert_container(
        [["x"]], "dataframe", ["A"], categorical_feature_names=["A"]
    )
    first_dtype = frame.dtypes.iloc[0]
    assert first_dtype == "category"
|
||
|
|
||
|
|
||
|
def test_convert_container_categories_polars():
    """Check that categorical feature names give a polars Categorical column."""
    pl = pytest.importorskip("polars")
    frame = _convert_container(
        [["x"]], "polars", ["A"], categorical_feature_names=["A"]
    )
    column_type = frame.schema["A"]
    assert column_type == pl.Categorical()
|
||
|
|
||
|
|
||
|
def test_convert_container_categories_pyarrow():
    """Check that categorical feature names give a pyarrow dictionary column."""
    pa = pytest.importorskip("pyarrow")
    table = _convert_container(
        [["x"]], "pyarrow", ["A"], categorical_feature_names=["A"]
    )
    first_field_type = table.schema[0].type
    assert type(first_field_type) is pa.DictionaryType
|
||
|
|
||
|
|
||
|
@pytest.mark.skipif(
    sp_version >= parse_version("1.8"),
    reason="sparse arrays are available as of scipy 1.8.0",
)
@pytest.mark.parametrize("constructor_name", ["sparse_csr_array", "sparse_csc_array"])
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_convert_container_raise_when_sparray_not_available(constructor_name, dtype):
    """Check that asking for a sparse array with a SciPy that lacks them
    (scipy<1.8.0) raises an explicit error."""
    expected_msg = f"only available with scipy>=1.8.0, got {sp_version}"
    with pytest.raises(ValueError, match=expected_msg):
        _convert_container([0, 1], constructor_name, dtype=dtype)
|
||
|
|
||
|
|
||
|
def test_raises():
    """Exercise the `raises` context manager on matching and failing cases."""

    # Expected type raised, no pattern requested.
    with raises(TypeError):
        raise TypeError()

    # Expected type raised and the single pattern matches.
    with raises(TypeError, match="how are you") as ctx:
        raise TypeError("hello how are you")
    assert ctx.raised_and_matched

    # Expected type raised and one of several patterns matches.
    with raises(TypeError, match=["not this one", "how are you"]) as ctx:
        raise TypeError("hello how are you")
    assert ctx.raised_and_matched

    # Unexpected type, no pattern: the original exception propagates.
    with pytest.raises(ValueError, match="this will be raised"):
        with raises(TypeError) as ctx:
            raise ValueError("this will be raised")
    assert not ctx.raised_and_matched

    # Unexpected type with an err_msg: reported as an AssertionError.
    with pytest.raises(AssertionError, match="the failure message"):
        with raises(TypeError, err_msg="the failure message") as ctx:
            raise ValueError()
    assert not ctx.raised_and_matched

    # Unexpected type with a pattern: the pattern is ignored anyway.
    with pytest.raises(ValueError, match="this will be raised"):
        with raises(TypeError, match="this is ignored") as ctx:
            raise ValueError("this will be raised")
    assert not ctx.raised_and_matched

    # Expected type but the pattern does not match.
    with pytest.raises(
        AssertionError, match="should contain one of the following patterns"
    ):
        with raises(TypeError, match="hello") as ctx:
            raise TypeError("Bad message")
    assert not ctx.raised_and_matched

    # Expected type, pattern does not match, custom err_msg.
    with pytest.raises(AssertionError, match="the failure message"):
        with raises(TypeError, match="hello", err_msg="the failure message") as ctx:
            raise TypeError("Bad message")
    assert not ctx.raised_and_matched

    # Nothing raised with the default may_pass=False.
    with pytest.raises(AssertionError, match="Did not raise"):
        with raises(TypeError) as ctx:
            pass
    assert not ctx.raised_and_matched

    # Nothing raised with may_pass=True: accepted.
    with raises(TypeError, match="hello", may_pass=True) as ctx:
        pass  # still OK
    assert not ctx.raised_and_matched

    # A tuple of exception types accepts any of its members...
    accepted_types = (TypeError, ValueError)
    for exc_type in accepted_types:
        with raises((TypeError, ValueError)):
            raise exc_type()
    # ...but still fails when nothing is raised.
    with pytest.raises(AssertionError):
        with raises((TypeError, ValueError)):
            pass
|
||
|
|
||
|
|
||
|
def test_float32_aware_assert_allclose():
    """Check the dtype-dependent default tolerances of `assert_allclose`."""

    def as_array(value, dtype):
        return np.array([value], dtype=dtype)

    # The relative tolerance for float32 inputs is 1e-4
    assert_allclose(as_array(1.0 + 2e-5, np.float32), 1.0)
    with pytest.raises(AssertionError):
        assert_allclose(as_array(1.0 + 2e-4, np.float32), 1.0)

    # The relative tolerance for other inputs is left to 1e-7 as in
    # the original numpy version.
    assert_allclose(as_array(1.0 + 2e-8, np.float64), 1.0)
    with pytest.raises(AssertionError):
        assert_allclose(as_array(1.0 + 2e-7, np.float64), 1.0)

    # atol is left to 0.0 by default, even for float32
    tiny = as_array(1e-5, np.float32)
    with pytest.raises(AssertionError):
        assert_allclose(tiny, 0.0)
    assert_allclose(as_array(1e-5, np.float32), 0.0, atol=2e-5)
|
||
|
|
||
|
|
||
|
@pytest.mark.xfail(_IS_WASM, reason="cannot start subprocess")
def test_assert_run_python_script_without_output():
    """Check `assert_run_python_script_without_output` on quiet and noisy code."""
    # A silent script passes.
    assert_run_python_script_without_output("x = 1")

    # Any output fails when no pattern is given.
    stdout_code = "print('something to stdout')"
    with pytest.raises(AssertionError, match="Expected no output"):
        assert_run_python_script_without_output(stdout_code)

    # Output on stdout that matches the forbidden pattern fails.
    with pytest.raises(
        AssertionError,
        match="output was not supposed to match.+got.+something to stdout",
    ):
        assert_run_python_script_without_output(stdout_code, pattern="to.+stdout")

    # Output on stderr is checked against the pattern as well.
    stderr_lines = ["import sys", "print('something to stderr', file=sys.stderr)"]
    stderr_code = "\n".join(stderr_lines)
    with pytest.raises(
        AssertionError,
        match="output was not supposed to match.+got.+something to stderr",
    ):
        assert_run_python_script_without_output(stderr_code, pattern="to.+stderr")
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constructor_name",
    [
        "sparse_csr",
        "sparse_csc",
        # The sparse array variants are skipped on SciPy versions without them.
        *(
            pytest.param(
                sparray_name,
                marks=pytest.mark.skipif(
                    sp_version < parse_version("1.8"),
                    reason="sparse arrays are available as of scipy 1.8.0",
                ),
            )
            for sparray_name in ("sparse_csr_array", "sparse_csc_array")
        ),
    ],
)
def test_convert_container_sparse_to_sparse(constructor_name):
    """Non-regression test: converting an already sparse container from one
    format to another must keep working.
    """
    csr_input = sparse.random(10, 10, density=0.1, format="csr")
    _convert_container(csr_input, constructor_name)
|
||
|
|
||
|
|
||
|
def check_warnings_as_errors(warning_info, warnings_as_errors):
    """Check that one warning-filter entry behaves as expected.

    Parameters
    ----------
    warning_info : object
        Filter description exposing `action`, `message` and `category`
        attributes.
    warnings_as_errors : bool
        When True, an "error" entry must raise and an "ignore" entry must be
        silenced. When False, an "ignore" entry is expected to be recorded
        exactly once.

    Raises
    ------
    AssertionError
        If the number or content of the recorded warnings is not the expected
        one.
    """
    if warning_info.action == "error" and warnings_as_errors:
        with pytest.raises(warning_info.category, match=warning_info.message):
            warnings.warn(
                message=warning_info.message,
                category=warning_info.category,
            )
    if warning_info.action == "ignore":
        with warnings.catch_warnings(record=True) as record:
            message = warning_info.message
            # Special treatment when regex is used
            if "Pyarrow" in message:
                message = "\nPyarrow will become a required dependency"

            warnings.warn(
                message=message,
                category=warning_info.category,
            )
            # The parentheses matter: without them the conditional expression
            # binds looser than `==` and the False branch would assert the
            # constant 1, which can never fail.
            expected_count = 0 if warnings_as_errors else 1
            assert len(record) == expected_count
            if record:
                assert str(record[0].message) == message
                assert record[0].category == warning_info.category
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("warning_info", _get_warnings_filters_info_list())
def test_sklearn_warnings_as_errors(warning_info):
    """Check each warning filter against the current environment setting."""
    # Any value other than "0" for SKLEARN_WARNINGS_AS_ERRORS enables the mode.
    env_flag = os.environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0")
    check_warnings_as_errors(warning_info, warnings_as_errors=env_flag != "0")
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("warning_info", _get_warnings_filters_info_list())
def test_turn_warnings_into_errors(warning_info):
    """Check each warning filter after calling `turn_warnings_into_errors`."""
    # `catch_warnings` restores the filters on exit, so the filters installed
    # here do not leak out of this test.
    saved_filters = warnings.catch_warnings()
    with saved_filters:
        turn_warnings_into_errors()
        check_warnings_as_errors(warning_info, warnings_as_errors=True)
|