1840 lines
59 KiB
Python
1840 lines
59 KiB
Python
from functools import partial
|
|
from inspect import signature
|
|
from itertools import chain, permutations, product
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from sklearn._config import config_context
|
|
from sklearn.datasets import make_multilabel_classification
|
|
from sklearn.metrics import (
|
|
accuracy_score,
|
|
average_precision_score,
|
|
balanced_accuracy_score,
|
|
brier_score_loss,
|
|
cohen_kappa_score,
|
|
confusion_matrix,
|
|
coverage_error,
|
|
d2_absolute_error_score,
|
|
d2_pinball_score,
|
|
d2_tweedie_score,
|
|
dcg_score,
|
|
det_curve,
|
|
explained_variance_score,
|
|
f1_score,
|
|
fbeta_score,
|
|
hamming_loss,
|
|
hinge_loss,
|
|
jaccard_score,
|
|
label_ranking_average_precision_score,
|
|
label_ranking_loss,
|
|
log_loss,
|
|
matthews_corrcoef,
|
|
max_error,
|
|
mean_absolute_error,
|
|
mean_absolute_percentage_error,
|
|
mean_gamma_deviance,
|
|
mean_pinball_loss,
|
|
mean_poisson_deviance,
|
|
mean_squared_error,
|
|
mean_tweedie_deviance,
|
|
median_absolute_error,
|
|
multilabel_confusion_matrix,
|
|
ndcg_score,
|
|
precision_recall_curve,
|
|
precision_score,
|
|
r2_score,
|
|
recall_score,
|
|
roc_auc_score,
|
|
roc_curve,
|
|
top_k_accuracy_score,
|
|
zero_one_loss,
|
|
)
|
|
from sklearn.metrics._base import _average_binary_score
|
|
from sklearn.preprocessing import LabelBinarizer
|
|
from sklearn.utils import shuffle
|
|
from sklearn.utils._array_api import (
|
|
_atol_for_type,
|
|
yield_namespace_device_dtype_combinations,
|
|
)
|
|
from sklearn.utils._testing import (
|
|
_array_api_for_tests,
|
|
assert_allclose,
|
|
assert_almost_equal,
|
|
assert_array_equal,
|
|
assert_array_less,
|
|
ignore_warnings,
|
|
)
|
|
from sklearn.utils.fixes import COO_CONTAINERS
|
|
from sklearn.utils.multiclass import type_of_target
|
|
from sklearn.utils.validation import _num_samples, check_random_state
|
|
|
|
# Note toward developers about metric testing
|
|
# -------------------------------------------
|
|
# It is often possible to write one general test for several metrics:
|
|
#
|
|
# - invariance properties, e.g. invariance to sample order
|
|
# - common behavior for an argument, e.g. the "normalize" with value True
|
|
# will return the mean of the metrics and with value False will return
|
|
# the sum of the metrics.
|
|
#
|
|
# In order to improve the overall metric testing, it is a good idea to write
|
|
# first a specific test for the given metric and then add a general test for
|
|
# all metrics that have the same behavior.
|
|
#
|
|
# Two types of data structures are used in order to implement this system:
|
|
# dictionaries of metrics and lists of metrics with common properties.
|
|
#
|
|
# Dictionaries of metrics
|
|
# ------------------------
|
|
# The goal of having those dictionaries is to have an easy way to call a
|
|
# particular metric and associate a name to each function:
|
|
#
|
|
# - REGRESSION_METRICS: all regression metrics.
|
|
# - CLASSIFICATION_METRICS: all classification metrics
|
|
# which compare a ground truth and the estimated targets as returned by a
|
|
# classifier.
|
|
# - THRESHOLDED_METRICS: all classification metrics which
|
|
# compare a ground truth and a score, e.g. estimated probabilities or
|
|
# decision function (format might vary)
|
|
#
|
|
# Those dictionaries will be used to test systematically some invariance
|
|
# properties, e.g. invariance toward several input layouts.
|
|
#
|
|
|
|
# Name -> callable for every regression metric under test. Entries built with
# ``partial`` pin a keyword argument so that all values share the
# ``metric(y_true, y_pred)`` calling convention used by the tests.
REGRESSION_METRICS = {
    # Error-based metrics
    "max_error": max_error,
    "mean_absolute_error": mean_absolute_error,
    "mean_squared_error": mean_squared_error,
    "mean_pinball_loss": mean_pinball_loss,
    "median_absolute_error": median_absolute_error,
    "mean_absolute_percentage_error": mean_absolute_percentage_error,
    # Scores
    "explained_variance_score": explained_variance_score,
    "r2_score": partial(r2_score, multioutput="variance_weighted"),
    # Deviances
    "mean_normal_deviance": partial(mean_tweedie_deviance, power=0),
    "mean_poisson_deviance": mean_poisson_deviance,
    "mean_gamma_deviance": mean_gamma_deviance,
    "mean_compound_poisson_deviance": partial(mean_tweedie_deviance, power=1.4),
    # D2 scores
    "d2_tweedie_score": partial(d2_tweedie_score, power=1.4),
    "d2_pinball_score": d2_pinball_score,
    "d2_absolute_error_score": d2_absolute_error_score,
}
|
|
|
|
def _normalized_confusion_matrix(*args, **kwargs):
    """Return the confusion matrix with every row divided by its row sum.

    All arguments are forwarded unchanged to ``confusion_matrix``. The matrix
    is computed once and reused for both the numerator and the row sums.
    """
    cm = confusion_matrix(*args, **kwargs)
    return cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]


# Name -> callable for every classification metric that compares a ground
# truth with the targets predicted by a classifier.
CLASSIFICATION_METRICS = {
    "accuracy_score": accuracy_score,
    "balanced_accuracy_score": balanced_accuracy_score,
    "adjusted_balanced_accuracy_score": partial(balanced_accuracy_score, adjusted=True),
    "unnormalized_accuracy_score": partial(accuracy_score, normalize=False),
    # `confusion_matrix` returns absolute values and hence behaves
    # unnormalized. Naming it with an unnormalized_ prefix is necessary for
    # this module to skip sample_weight scaling checks which will fail for
    # unnormalized metrics.
    "unnormalized_confusion_matrix": confusion_matrix,
    "normalized_confusion_matrix": _normalized_confusion_matrix,
    "unnormalized_multilabel_confusion_matrix": multilabel_confusion_matrix,
    "unnormalized_multilabel_confusion_matrix_sample": partial(
        multilabel_confusion_matrix, samplewise=True
    ),
    "hamming_loss": hamming_loss,
    "zero_one_loss": zero_one_loss,
    "unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False),
    # These are needed to test averaging
    "jaccard_score": jaccard_score,
    "precision_score": precision_score,
    "recall_score": recall_score,
    "f1_score": f1_score,
    "f2_score": partial(fbeta_score, beta=2),
    "f0.5_score": partial(fbeta_score, beta=0.5),
    "matthews_corrcoef_score": matthews_corrcoef,
    "weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5),
    "weighted_f1_score": partial(f1_score, average="weighted"),
    "weighted_f2_score": partial(fbeta_score, average="weighted", beta=2),
    "weighted_precision_score": partial(precision_score, average="weighted"),
    "weighted_recall_score": partial(recall_score, average="weighted"),
    "weighted_jaccard_score": partial(jaccard_score, average="weighted"),
    "micro_f0.5_score": partial(fbeta_score, average="micro", beta=0.5),
    "micro_f1_score": partial(f1_score, average="micro"),
    "micro_f2_score": partial(fbeta_score, average="micro", beta=2),
    "micro_precision_score": partial(precision_score, average="micro"),
    "micro_recall_score": partial(recall_score, average="micro"),
    "micro_jaccard_score": partial(jaccard_score, average="micro"),
    "macro_f0.5_score": partial(fbeta_score, average="macro", beta=0.5),
    "macro_f1_score": partial(f1_score, average="macro"),
    "macro_f2_score": partial(fbeta_score, average="macro", beta=2),
    "macro_precision_score": partial(precision_score, average="macro"),
    "macro_recall_score": partial(recall_score, average="macro"),
    "macro_jaccard_score": partial(jaccard_score, average="macro"),
    "samples_f0.5_score": partial(fbeta_score, average="samples", beta=0.5),
    "samples_f1_score": partial(f1_score, average="samples"),
    "samples_f2_score": partial(fbeta_score, average="samples", beta=2),
    "samples_precision_score": partial(precision_score, average="samples"),
    "samples_recall_score": partial(recall_score, average="samples"),
    "samples_jaccard_score": partial(jaccard_score, average="samples"),
    "cohen_kappa_score": cohen_kappa_score,
}
|
|
|
|
|
|
def precision_recall_curve_padded_thresholds(*args, **kwargs):
    """Return precision, recall and NaN-padded thresholds as a single array.

    The precision/recall arrays and the threshold array returned by
    :func:`sklearn.metrics.precision_recall_curve` do not have matching
    dimensions. Converting the raw triple to an array would therefore not
    produce a regular ``dtype('float64')`` array (it would be of
    ``dtype('object')`` instead), and ``assert_array_equal`` needs a regular
    array to work correctly.

    As a workaround, the thresholds are padded at the end with NaN values
    until they are as long as the precision array.

    All arguments are forwarded to ``precision_recall_curve``.
    """
    precision, recall, thresholds = precision_recall_curve(*args, **kwargs)

    # Number of NaN entries needed to bring thresholds up to len(precision).
    n_missing = len(precision) - len(thresholds)
    padded_thresholds = np.pad(
        thresholds.astype(np.float64),
        pad_width=(0, n_missing),
        mode="constant",
        constant_values=[np.nan],
    )

    return np.array([precision, recall, padded_thresholds])
|
|
|
|
|
|
# Name -> callable for metrics that return a whole curve rather than a scalar.
# The precision-recall entry uses the padded wrapper so that its result is a
# regular array.
CURVE_METRICS = {
    "roc_curve": roc_curve,
    "precision_recall_curve": precision_recall_curve_padded_thresholds,
    "det_curve": det_curve,
}
|
|
|
|
# Name -> callable for classification metrics that compare a ground truth
# with a score (e.g. estimated probabilities or a decision function).
THRESHOLDED_METRICS = {
    "coverage_error": coverage_error,
    "label_ranking_loss": label_ranking_loss,
    "log_loss": log_loss,
    "unnormalized_log_loss": partial(log_loss, normalize=False),
    "hinge_loss": hinge_loss,
    "brier_score_loss": brier_score_loss,
    # ROC AUC variants
    "roc_auc_score": roc_auc_score,  # default: average="macro"
    "weighted_roc_auc": partial(roc_auc_score, average="weighted"),
    "samples_roc_auc": partial(roc_auc_score, average="samples"),
    "micro_roc_auc": partial(roc_auc_score, average="micro"),
    "ovr_roc_auc": partial(roc_auc_score, average="macro", multi_class="ovr"),
    "weighted_ovr_roc_auc": partial(
        roc_auc_score, average="weighted", multi_class="ovr"
    ),
    "ovo_roc_auc": partial(roc_auc_score, average="macro", multi_class="ovo"),
    "weighted_ovo_roc_auc": partial(
        roc_auc_score, average="weighted", multi_class="ovo"
    ),
    "partial_roc_auc": partial(roc_auc_score, max_fpr=0.5),
    # Average precision variants
    "average_precision_score": average_precision_score,  # default: average="macro"
    "weighted_average_precision_score": partial(
        average_precision_score, average="weighted"
    ),
    "samples_average_precision_score": partial(
        average_precision_score, average="samples"
    ),
    "micro_average_precision_score": partial(average_precision_score, average="micro"),
    # Ranking metrics
    "label_ranking_average_precision_score": label_ranking_average_precision_score,
    "ndcg_score": ndcg_score,
    "dcg_score": dcg_score,
    "top_k_accuracy_score": top_k_accuracy_score,
}
|
|
|
|
# Union of all metric dictionaries. Later entries win on a name clash, in the
# order thresholded -> classification -> regression -> curve.
ALL_METRICS = {
    **THRESHOLDED_METRICS,
    **CLASSIFICATION_METRICS,
    **REGRESSION_METRICS,
    **CURVE_METRICS,
}
|
|
|
|
# Lists of metrics with common properties
|
|
# ---------------------------------------
|
|
# Lists of metrics with common properties are used to test systematically some
|
|
# functionalities and invariance, e.g. SYMMETRIC_METRICS lists all metrics that
|
|
# are symmetric with respect to their input argument y_true and y_pred.
|
|
#
|
|
# When you add a new metric or functionality, check if a general test
|
|
# is already written.
|
|
|
|
# Metrics that do not support binary inputs
METRIC_UNDEFINED_BINARY = {
    # sample-averaged scores
    "samples_f0.5_score",
    "samples_f1_score",
    "samples_f2_score",
    "samples_precision_score",
    "samples_recall_score",
    "samples_jaccard_score",
    "unnormalized_multilabel_confusion_matrix_sample",
    # ranking metrics
    "coverage_error",
    "label_ranking_loss",
    "label_ranking_average_precision_score",
    "dcg_score",
    "ndcg_score",
}

# Metrics that do not support multiclass inputs
METRIC_UNDEFINED_MULTICLASS = {
    "brier_score_loss",
    "micro_roc_auc",
    "samples_roc_auc",
    "partial_roc_auc",
    "roc_auc_score",
    "weighted_roc_auc",
    # with default average='binary', multiclass is prohibited
    "jaccard_score",
    "precision_score",
    "recall_score",
    "f1_score",
    "f2_score",
    "f0.5_score",
    # curves
    "roc_curve",
    "precision_recall_curve",
    "det_curve",
}

# Metrics undefined with "binary" or "multiclass" input
METRIC_UNDEFINED_BINARY_MULTICLASS = (
    METRIC_UNDEFINED_BINARY | METRIC_UNDEFINED_MULTICLASS
)
|
|
|
|
# Metrics with an "average" argument
METRICS_WITH_AVERAGING = {
    "precision_score",
    "recall_score",
    "f1_score",
    "f2_score",
    "f0.5_score",
    "jaccard_score",
}

# Threshold-based metrics with an "average" argument
THRESHOLDED_METRICS_WITH_AVERAGING = {
    "roc_auc_score",
    "average_precision_score",
    "partial_roc_auc",
}

# Metrics with a "pos_label" argument
METRICS_WITH_POS_LABEL = {
    # curves
    "roc_curve",
    "precision_recall_curve",
    "det_curve",
    "brier_score_loss",
    # default (binary) averaging
    "precision_score",
    "recall_score",
    "f1_score",
    "f2_score",
    "f0.5_score",
    "jaccard_score",
    # average precision variants
    "average_precision_score",
    "weighted_average_precision_score",
    "micro_average_precision_score",
    "samples_average_precision_score",
}

# Metrics with a "labels" argument
# TODO: Handle multi_class metrics that have a labels argument as well as a
# decision function argument, e.g. hinge_loss
METRICS_WITH_LABELS = {
    "unnormalized_confusion_matrix",
    "normalized_confusion_matrix",
    "roc_curve",
    "precision_recall_curve",
    "det_curve",
    "precision_score",
    "recall_score",
    "f1_score",
    "f2_score",
    "f0.5_score",
    "jaccard_score",
    "unnormalized_multilabel_confusion_matrix",
    "unnormalized_multilabel_confusion_matrix_sample",
    "cohen_kappa_score",
} | {
    # weighted/micro/macro variants of the precision-recall family
    f"{average}_{score}"
    for average in ("weighted", "micro", "macro")
    for score in (
        "f0.5_score",
        "f1_score",
        "f2_score",
        "precision_score",
        "recall_score",
        "jaccard_score",
    )
}

# Metrics with a "normalize" option
METRICS_WITH_NORMALIZE_OPTION = {
    "accuracy_score",
    "top_k_accuracy_score",
    "zero_one_loss",
}
|
|
|
|
# Threshold-based metrics with "multilabel-indicator" format support
THRESHOLDED_MULTILABEL_METRICS = {
    "log_loss",
    "unnormalized_log_loss",
    # ROC AUC variants
    "roc_auc_score",
    "weighted_roc_auc",
    "samples_roc_auc",
    "micro_roc_auc",
    "partial_roc_auc",
    # average precision variants
    "average_precision_score",
    "weighted_average_precision_score",
    "samples_average_precision_score",
    "micro_average_precision_score",
    # ranking metrics
    "coverage_error",
    "label_ranking_loss",
    "ndcg_score",
    "dcg_score",
    "label_ranking_average_precision_score",
}

# Classification metrics with "multilabel-indicator" format
MULTILABELS_METRICS = {
    "accuracy_score",
    "unnormalized_accuracy_score",
    "hamming_loss",
    "zero_one_loss",
    "unnormalized_zero_one_loss",
    "unnormalized_multilabel_confusion_matrix",
} | {
    # every averaging variant of the precision-recall family
    f"{average}_{score}"
    for average in ("weighted", "macro", "micro", "samples")
    for score in (
        "f0.5_score",
        "f1_score",
        "f2_score",
        "precision_score",
        "recall_score",
        "jaccard_score",
    )
}

# Regression metrics with "multioutput-continuous" format support
MULTIOUTPUT_METRICS = {
    "mean_absolute_error",
    "median_absolute_error",
    "mean_squared_error",
    "r2_score",
    "explained_variance_score",
    "mean_absolute_percentage_error",
    "mean_pinball_loss",
    "d2_pinball_score",
    "d2_absolute_error_score",
}
|
|
|
|
# Symmetric with respect to their input arguments y_true and y_pred
# metric(y_true, y_pred) == metric(y_pred, y_true).
SYMMETRIC_METRICS = {
    "accuracy_score",
    "unnormalized_accuracy_score",
    "hamming_loss",
    "zero_one_loss",
    "unnormalized_zero_one_loss",
    "micro_jaccard_score",
    "macro_jaccard_score",
    "jaccard_score",
    "samples_jaccard_score",
    "f1_score",
    "macro_f1_score",
    "weighted_recall_score",
    # P = R = F = accuracy in multiclass case
    "micro_f0.5_score",
    "micro_f1_score",
    "micro_f2_score",
    "micro_precision_score",
    "micro_recall_score",
    "matthews_corrcoef_score",
    "mean_absolute_error",
    "mean_squared_error",
    "median_absolute_error",
    "max_error",
    # Pinball loss is only symmetric for alpha=0.5 which is the default.
    "mean_pinball_loss",
    "cohen_kappa_score",
    "mean_normal_deviance",
}
|
|
|
|
# Asymmetric with respect to their input arguments y_true and y_pred
# metric(y_true, y_pred) != metric(y_pred, y_true).
NOT_SYMMETRIC_METRICS = {
    # classification metrics
    "balanced_accuracy_score",
    "adjusted_balanced_accuracy_score",
    "unnormalized_confusion_matrix",
    "normalized_confusion_matrix",
    "unnormalized_multilabel_confusion_matrix",
    "precision_score",
    "recall_score",
    "f2_score",
    "f0.5_score",
    "weighted_f0.5_score",
    "weighted_f1_score",
    "weighted_f2_score",
    "weighted_precision_score",
    "weighted_jaccard_score",
    "macro_f0.5_score",
    "macro_f2_score",
    "macro_precision_score",
    "macro_recall_score",
    "hinge_loss",
    # curves
    "roc_curve",
    "precision_recall_curve",
    "det_curve",
    # regression metrics
    "explained_variance_score",
    "r2_score",
    "mean_gamma_deviance",
    "mean_poisson_deviance",
    "mean_compound_poisson_deviance",
    "d2_tweedie_score",
    "d2_pinball_score",
    "d2_absolute_error_score",
    "mean_absolute_percentage_error",
}
|
|
|
|
|
|
# Metrics without sample_weight support
METRICS_WITHOUT_SAMPLE_WEIGHT = {
    "median_absolute_error",
    "max_error",
    "ovo_roc_auc",
    "weighted_ovo_roc_auc",
}

# Metrics whose targets are shifted to be strictly positive before being
# evaluated in the tests of this module.
METRICS_REQUIRE_POSITIVE_Y = {
    "mean_poisson_deviance",
    "mean_gamma_deviance",
    "mean_compound_poisson_deviance",
    "d2_tweedie_score",
}
|
|
|
|
|
|
def _require_positive_targets(y1, y2):
    """Make targets strictly positive.

    Both targets are shifted by the same offset, ``abs(m) + 1`` where ``m``
    is the smallest value found in either input.

    The inputs are left untouched: shifted copies are returned instead of
    modifying the caller's arrays in place. This also allows inputs with
    different dtypes (e.g. integer and float), for which an in-place addition
    on the integer array would raise.

    Parameters
    ----------
    y1, y2 : array-like
        Targets to shift. Converted with ``np.asarray``.

    Returns
    -------
    y1_shifted, y2_shifted : ndarray
        The shifted targets, in the same order as the inputs.
    """
    y1, y2 = np.asarray(y1), np.asarray(y2)
    offset = abs(min(y1.min(), y2.min())) + 1
    return y1 + offset, y2 + offset
|
|
|
|
|
|
def test_symmetry_consistency():
    """Check that the symmetry lists cover every metric exactly once."""
    # We shouldn't forget any metrics: each one must be listed as symmetric,
    # as not symmetric, as thresholded, or as undefined for binary/multiclass
    # inputs.
    accounted_for = set(THRESHOLDED_METRICS).union(
        SYMMETRIC_METRICS,
        NOT_SYMMETRIC_METRICS,
        METRIC_UNDEFINED_BINARY_MULTICLASS,
    )
    assert accounted_for == set(ALL_METRICS)

    # No metric may be listed as both symmetric and not symmetric.
    assert SYMMETRIC_METRICS.isdisjoint(NOT_SYMMETRIC_METRICS)
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(SYMMETRIC_METRICS))
def test_symmetric_metric(name):
    """Check that swapping y_true and y_pred leaves the metric unchanged."""
    rng = check_random_state(0)
    y_true = rng.randint(0, 2, size=(20,))
    y_pred = rng.randint(0, 2, size=(20,))

    if name in METRICS_REQUIRE_POSITIVE_Y:
        y_true, y_pred = _require_positive_targets(y_true, y_pred)

    # Multilabel indicator data, used for metrics undefined on binary input.
    y_true_bin = rng.randint(0, 2, size=(20, 25))
    y_pred_bin = rng.randint(0, 2, size=(20, 25))

    metric = ALL_METRICS[name]

    # Pick the input format the metric is defined for.
    if name not in METRIC_UNDEFINED_BINARY:
        first, second = y_true, y_pred
    elif name in MULTILABELS_METRICS:
        first, second = y_true_bin, y_pred_bin
    else:
        assert False, "This case is currently unhandled"

    assert_allclose(
        metric(first, second),
        metric(second, first),
        err_msg="%s is not symmetric" % name,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(NOT_SYMMETRIC_METRICS))
def test_not_symmetric_metric(name):
    """Check that swapping y_true and y_pred changes the metric value."""
    rng = check_random_state(0)
    y_true = rng.randint(0, 2, size=(20,))
    y_pred = rng.randint(0, 2, size=(20,))

    if name in METRICS_REQUIRE_POSITIVE_Y:
        y_true, y_pred = _require_positive_targets(y_true, y_pred)

    metric = ALL_METRICS[name]

    # The equality assertion is expected to fail with an AssertionError. If
    # it unexpectedly passes, the ValueError below is not caught by
    # ``pytest.raises`` and fails the test with a custom error message.
    with pytest.raises(AssertionError):
        assert_array_equal(metric(y_true, y_pred), metric(y_pred, y_true))
        raise ValueError("%s seems to be symmetric" % name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(set(ALL_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS)
)
def test_sample_order_invariance(name):
    """Check that shuffling the samples consistently leaves the metric unchanged."""
    rng = check_random_state(0)
    y_true = rng.randint(0, 2, size=(20,))
    y_pred = rng.randint(0, 2, size=(20,))

    if name in METRICS_REQUIRE_POSITIVE_Y:
        y_true, y_pred = _require_positive_targets(y_true, y_pred)

    # Both arrays are shuffled together by a single ``shuffle`` call.
    y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)

    with ignore_warnings():
        metric = ALL_METRICS[name]
        expected = metric(y_true, y_pred)
        shuffled = metric(y_true_shuffle, y_pred_shuffle)
        assert_allclose(
            expected,
            shuffled,
            err_msg="%s is not sample order invariant" % name,
        )
|
|
|
|
|
|
@ignore_warnings
def test_sample_order_invariance_multilabel_and_multioutput():
    """Check sample order invariance on 2D (multilabel/multioutput) targets."""
    rng = check_random_state(0)

    # Generate some data
    y_true = rng.randint(0, 2, size=(20, 25))
    y_pred = rng.randint(0, 2, size=(20, 25))
    y_score = rng.normal(size=y_true.shape)

    y_true_shuffle, y_pred_shuffle, y_score_shuffle = shuffle(
        y_true, y_pred, y_score, random_state=0
    )

    # For each group of metrics, the (original, shuffled) second arguments to
    # evaluate: hard predictions for multilabel classification metrics,
    # scores for thresholded metrics, and both for multioutput regression.
    pred_pair = (y_pred, y_pred_shuffle)
    score_pair = (y_score, y_score_shuffle)
    groups = [
        (MULTILABELS_METRICS, [pred_pair]),
        (THRESHOLDED_MULTILABEL_METRICS, [score_pair]),
        (MULTIOUTPUT_METRICS, [score_pair, pred_pair]),
    ]

    for names, input_pairs in groups:
        for name in names:
            metric = ALL_METRICS[name]
            for y_hat, y_hat_shuffle in input_pairs:
                assert_allclose(
                    metric(y_true, y_hat),
                    metric(y_true_shuffle, y_hat_shuffle),
                    err_msg="%s is not sample order invariant" % name,
                )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(set(ALL_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS)
)
def test_format_invariance_with_1d_vectors(name):
    """Check that list, 1d-array and column-vector inputs give the same value.

    Combinations involving a row vector must raise ``ValueError``.
    """
    rng = check_random_state(0)
    y1 = rng.randint(0, 2, size=(20,))
    y2 = rng.randint(0, 2, size=(20,))

    if name in METRICS_REQUIRE_POSITIVE_Y:
        y1, y2 = _require_positive_targets(y1, y2)

    y1_list = list(y1)
    y2_list = list(y2)

    y1_1d, y2_1d = np.array(y1), np.array(y2)
    assert_array_equal(y1_1d.ndim, 1)
    assert_array_equal(y2_1d.ndim, 1)
    y1_column = np.reshape(y1_1d, (-1, 1))
    y2_column = np.reshape(y2_1d, (-1, 1))
    y1_row = np.reshape(y1_1d, (1, -1))
    y2_row = np.reshape(y2_1d, (1, -1))

    # (first argument, second argument, description used in the message)
    supported_layouts = [
        (y1_list, y2_list, "list"),
        (y1_1d, y2_1d, "np-array-1d"),
        (y1_column, y2_column, "np-array-column"),
        # Mix format support
        (y1_1d, y2_list, "mix np-array-1d and list"),
        (y1_list, y2_1d, "mix np-array-1d and list"),
        (y1_1d, y2_column, "mix np-array-1d and np-array-column"),
        (y1_column, y2_1d, "mix np-array-1d and np-array-column"),
        (y1_list, y2_column, "mix list and np-array-column"),
        (y1_column, y2_list, "mix list and np-array-column"),
    ]

    # These mix representations aren't allowed
    rejected_layouts = [
        (y1_1d, y2_row),
        (y1_row, y2_1d),
        (y1_list, y2_row),
        (y1_row, y2_list),
        (y1_column, y2_row),
        (y1_row, y2_column),
    ]

    with ignore_warnings():
        metric = ALL_METRICS[name]
        measure = metric(y1, y2)

        for first, second, layout in supported_layouts:
            assert_allclose(
                metric(first, second),
                measure,
                err_msg=f"{name} is not representation invariant with {layout}",
            )

        for first, second in rejected_layouts:
            with pytest.raises(ValueError):
                metric(first, second)

        # NB: We do not test for y1_row, y2_row as these may be
        # interpreted as multilabel or multioutput data.
        accepts_2d_targets = (
            MULTIOUTPUT_METRICS | THRESHOLDED_MULTILABEL_METRICS | MULTILABELS_METRICS
        )
        if name not in accepts_2d_targets:
            with pytest.raises(ValueError):
                metric(y1_row, y2_row)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(set(CLASSIFICATION_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS)
)
def test_classification_invariance_string_vs_numbers_labels(name):
    """Ensure that classification metrics with string labels are invariant.

    The metric computed on integer labels must equal the one computed on the
    matching string labels, for both fixed-width and object dtype string
    arrays, with and without an explicit ``labels`` argument.
    """
    random_state = check_random_state(0)
    y1 = random_state.randint(0, 2, size=(20,))
    y2 = random_state.randint(0, 2, size=(20,))

    # Map the integer labels 0/1 onto "eggs"/"spam".
    y1_str = np.array(["eggs", "spam"])[y1]
    y2_str = np.array(["eggs", "spam"])[y2]

    pos_label_str = "spam"
    labels_str = ["eggs", "spam"]

    with ignore_warnings():
        metric = CLASSIFICATION_METRICS[name]
        measure_with_number = metric(y1, y2)

        def check_string_inputs(metric_str):
            # Compare against the numeric result for both string dtypes; the
            # two messages differ so a failure identifies the dtype at fault.
            measure_with_str = metric_str(y1_str, y2_str)
            assert_array_equal(
                measure_with_number,
                measure_with_str,
                err_msg="{0} failed string vs number invariance test".format(name),
            )

            measure_with_strobj = metric_str(y1_str.astype("O"), y2_str.astype("O"))
            assert_array_equal(
                measure_with_number,
                measure_with_strobj,
                err_msg="{0} failed string object vs number invariance test".format(
                    name
                ),
            )

        # Ugly, but handle case with a pos_label and label
        metric_str = metric
        if name in METRICS_WITH_POS_LABEL:
            metric_str = partial(metric_str, pos_label=pos_label_str)

        check_string_inputs(metric_str)

        if name in METRICS_WITH_LABELS:
            metric_str = partial(metric_str, labels=labels_str)
            check_string_inputs(metric_str)
|
|
|
|
|
|
@pytest.mark.parametrize("name", THRESHOLDED_METRICS)
def test_thresholded_invariance_string_vs_numbers_labels(name):
    """Ensure that thresholded metrics with string labels are invariant."""
    rng = check_random_state(0)
    y1 = rng.randint(0, 2, size=(20,))
    y2 = rng.randint(0, 2, size=(20,))

    # String version of the ground truth; the scores ``y2`` stay numeric.
    y1_str = np.array(["eggs", "spam"])[y1]
    y1_strobj = y1_str.astype("O")

    pos_label_str = "spam"

    with ignore_warnings():
        metric = THRESHOLDED_METRICS[name]

        if name in METRIC_UNDEFINED_BINARY:
            # TODO those metrics don't support string labels yet
            with pytest.raises(ValueError):
                metric(y1_str, y2)
            with pytest.raises(ValueError):
                metric(y1_strobj, y2)
        else:
            # Ugly, but handle case with a pos_label and label
            metric_str = metric
            if name in METRICS_WITH_POS_LABEL:
                metric_str = partial(metric_str, pos_label=pos_label_str)

            measure_with_number = metric(y1, y2)

            measure_with_str = metric_str(y1_str, y2)
            assert_array_equal(
                measure_with_number,
                measure_with_str,
                err_msg="{0} failed string vs number invariance test".format(name),
            )

            measure_with_strobj = metric_str(y1_strobj, y2)
            assert_array_equal(
                measure_with_number,
                measure_with_strobj,
                err_msg="{0} failed string object vs number invariance test".format(
                    name
                ),
            )
|
|
|
|
|
|
# (y_true, y_score) pairs whose second element contains at least one
# non-finite value (NaN and/or infinity).
invalids_nan_inf = [
    ([0, 1], list(non_finite_scores))
    for non_finite_scores in (
        (np.inf, np.inf),
        (np.nan, np.nan),
        (np.nan, np.inf),
        (np.inf, 1),
        (np.nan, 1),
    )
]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "metric", chain(THRESHOLDED_METRICS.values(), REGRESSION_METRICS.values())
)
@pytest.mark.parametrize("y_true, y_score", invalids_nan_inf)
def test_regression_thresholded_inf_nan_input(metric, y_true, y_score):
    """Check that NaN/inf in the second argument raises a ValueError."""
    # Wrap into a single row since coverage_error only accepts 2D arrays.
    if metric is coverage_error:
        y_true, y_score = [y_true], [y_score]

    with pytest.raises(ValueError, match=r"contains (NaN|infinity)"):
        metric(y_true, y_score)
|
|
|
|
|
|
@pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values())
@pytest.mark.parametrize(
    "y_true, y_score",
    invalids_nan_inf +
    # Add an additional case for classification only
    # non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/6809
    [
        ([np.nan, 1, 2], [1, 2, 3]),
        ([np.inf, 1, 2], [1, 2, 3]),
    ],  # type: ignore
)
def test_classification_inf_nan_input(metric, y_true, y_score):
    """Check that classification metrics raise a message mentioning the
    occurrence of non-finite values in the target vectors."""
    # The message names ``y_true`` when it holds a non-finite value, and
    # ``y_pred`` otherwise.
    if np.isfinite(y_true).all():
        input_name, offending = "y_pred", y_score
    else:
        input_name, offending = "y_true", y_true

    if np.isnan(offending).any():
        unexpected_value = "NaN"
    else:
        unexpected_value = "infinity or a value too large"

    err_msg = f"Input {input_name} contains {unexpected_value}"

    with pytest.raises(ValueError, match=err_msg):
        metric(y_true, y_score)
|
|
|
|
|
|
@pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values())
def test_classification_binary_continuous_input(metric):
    """Check that classification metrics raise a message of mixed type data
    with continuous/binary target vectors."""
    y_true = ["a", "b", "a"]
    y_score = [0.1, 0.2, 0.3]
    expected_message = (
        "Classification metrics can't handle a mix of binary and continuous targets"
    )
    with pytest.raises(ValueError, match=expected_message):
        metric(y_true, y_score)
|
|
|
|
|
|
@ignore_warnings
def check_single_sample(name):
    """Call the metric ``name`` on every pair of single-sample inputs.

    Non-regression test: scores should work with a single sample. This is
    important for leave-one-out cross validation. Score functions tested are
    those that formerly called np.squeeze, which turns an array of size 1
    into a 0-d array (!).
    """
    metric = ALL_METRICS[name]

    # Metrics listed in METRICS_REQUIRE_POSITIVE_Y get the values 1 and 2
    # instead of 0 and 1.
    values = [1, 2] if name in METRICS_REQUIRE_POSITIVE_Y else [0, 1]

    # assert that no exception is thrown
    for true_value, pred_value in product(values, repeat=2):
        metric([true_value], [pred_value])
|
|
|
|
|
|
@ignore_warnings
def check_single_sample_multioutput(name):
    """Call the metric ``name`` on every pair of single-row, two-column 0/1 inputs."""
    metric = ALL_METRICS[name]
    for true_0, true_1, pred_0, pred_1 in product([0, 1], repeat=4):
        y_true = np.array([[true_0, true_1]])
        y_pred = np.array([[pred_0, pred_1]])
        metric(y_true, y_pred)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name",
    sorted(
        set(ALL_METRICS)
        # Those metrics are not always defined with one sample
        # or in multiclass classification
        - METRIC_UNDEFINED_BINARY_MULTICLASS
        - set(THRESHOLDED_METRICS)
    ),
)
def test_single_sample(name):
    """Run the single-sample check for each metric in the parametrized list."""
    check_single_sample(name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(MULTIOUTPUT_METRICS.union(MULTILABELS_METRICS))
)
def test_single_sample_multioutput(name):
    """Run the single-sample smoke check on multioutput/multilabel metrics."""
    check_single_sample_multioutput(name)
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(MULTIOUTPUT_METRICS))
def test_multioutput_number_of_output_differ(name):
    """A ``ValueError`` is expected when targets and predictions differ in width."""
    metric = ALL_METRICS[name]

    four_outputs = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]])
    two_outputs = np.array([[0, 0], [1, 0], [0, 0]])

    with pytest.raises(ValueError):
        metric(four_outputs, two_outputs)
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(MULTIOUTPUT_METRICS))
def test_multioutput_regression_invariance_to_dimension_shuffling(name):
    """Multioutput metrics must not depend on the order of the output columns.

    The same column permutation is applied to ``y_true`` and ``y_pred`` three
    times and every permuted score is compared with the unpermuted one.
    """
    rng = check_random_state(0)
    shape = (20, 5)
    y_true = rng.uniform(0, 2, size=shape)
    y_pred = rng.uniform(0, 2, size=shape)

    metric = ALL_METRICS[name]
    reference = metric(y_true, y_pred)
    n_outputs = y_true.shape[1]

    for _ in range(3):
        column_order = rng.permutation(n_outputs)
        shuffled = metric(y_true[:, column_order], y_pred[:, column_order])
        assert_allclose(
            shuffled,
            reference,
            err_msg="%s is not dimension shuffling invariant" % (name),
        )
|
|
|
|
|
|
@ignore_warnings
@pytest.mark.parametrize("coo_container", COO_CONTAINERS)
def test_multilabel_representation_invariance(coo_container):
    """Multilabel metrics give one score whatever container holds the labels.

    Every metric in ``MULTILABELS_METRICS`` is evaluated on dense indicator
    arrays; the result is then compared with the score obtained from the same
    labels stored in ``coo_container``, as a list of lists and as a list of
    1-D arrays.
    """
    n_classes = 4
    n_samples = 50
    generator_options = dict(
        n_features=1,
        n_classes=n_classes,
        n_samples=n_samples,
        allow_unlabeled=True,
    )
    _, y1 = make_multilabel_classification(random_state=0, **generator_options)
    _, y2 = make_multilabel_classification(random_state=1, **generator_options)

    # Stack an all-zero row so that at least one sample carries no label.
    y1 = np.vstack([y1, [[0] * n_classes]])
    y2 = np.vstack([y2, [[0] * n_classes]])

    # Alternative containers for the very same indicator matrices.
    y1_sparse = coo_container(y1)
    y2_sparse = coo_container(y2)
    y1_rows = list(y1)
    y2_rows = list(y2)
    y1_nested = [list(row) for row in y1_rows]
    y2_nested = [list(row) for row in y2_rows]

    for name in MULTILABELS_METRICS:
        metric = ALL_METRICS[name]

        # XXX cruel hack to work with partial functions: give them a module
        # and a name attribute.
        if isinstance(metric, partial):
            metric.__module__ = "tmp"
            metric.__name__ = name

        reference = metric(y1, y2)

        # dense vs. sparse indicator
        assert_allclose(
            metric(y1_sparse, y2_sparse),
            reference,
            err_msg=(
                "%s failed representation invariance between "
                "dense and sparse indicator formats."
            )
            % name,
        )
        # dense vs. list of lists
        assert_almost_equal(
            metric(y1_nested, y2_nested),
            reference,
            err_msg=(
                "%s failed representation invariance "
                "between dense array and list of list "
                "indicator formats."
            )
            % name,
        )
        # dense vs. list of arrays
        assert_almost_equal(
            metric(y1_rows, y2_rows),
            reference,
            err_msg=(
                "%s failed representation invariance "
                "between dense and list of array "
                "indicator formats."
            )
            % name,
        )
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(MULTILABELS_METRICS))
def test_raise_value_error_multilabel_sequences(name):
    """The multilabel-sequence format must be rejected with ``ValueError``."""
    metric = ALL_METRICS[name]

    sequence_inputs = (
        [[1], [2], [0, 1]],
        [(), (2), (0, 1)],
        [[]],
        [()],
        np.array([[], [1, 2]], dtype="object"),
    )
    for sequence in sequence_inputs:
        # The same sequence is used as both targets and predictions.
        with pytest.raises(ValueError):
            metric(sequence, sequence)
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_NORMALIZE_OPTION))
def test_normalize_option_binary_classification(name):
    """Check the ``normalize`` option on binary targets.

    The normalized score must be strictly positive and equal to the
    unnormalized score divided by the number of samples.
    """
    n_classes, n_samples = 2, 20
    rng = check_random_state(0)

    y_true = rng.randint(0, n_classes, size=(n_samples,))
    y_pred = rng.randint(0, n_classes, size=(n_samples,))
    y_score = rng.normal(size=y_true.shape)

    metric = ALL_METRICS[name]
    # Thresholded metrics consume continuous scores, the others hard labels.
    if name in THRESHOLDED_METRICS:
        predictions = y_score
    else:
        predictions = y_pred

    normalized = metric(y_true, predictions, normalize=True)
    unnormalized = metric(y_true, predictions, normalize=False)

    # A zero score would make the ratio check below uninformative.
    assert_array_less(
        -1.0 * normalized,
        0,
        err_msg="We failed to test correctly the normalize option",
    )
    assert_allclose(
        normalized,
        unnormalized / n_samples,
        err_msg=f"Failed with {name}",
    )
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_NORMALIZE_OPTION))
def test_normalize_option_multiclass_classification(name):
    """Check the ``normalize`` option on multiclass targets.

    The normalized score must be strictly positive and equal to the
    unnormalized score divided by the number of samples.
    """
    n_classes, n_samples = 4, 20
    rng = check_random_state(0)

    y_true = rng.randint(0, n_classes, size=(n_samples,))
    y_pred = rng.randint(0, n_classes, size=(n_samples,))
    y_score = rng.uniform(size=(n_samples, n_classes))

    metric = ALL_METRICS[name]
    # Thresholded metrics consume per-class scores, the others hard labels.
    if name in THRESHOLDED_METRICS:
        predictions = y_score
    else:
        predictions = y_pred

    normalized = metric(y_true, predictions, normalize=True)
    unnormalized = metric(y_true, predictions, normalize=False)

    # A zero score would make the ratio check below uninformative.
    assert_array_less(
        -1.0 * normalized,
        0,
        err_msg="We failed to test correctly the normalize option",
    )
    assert_allclose(
        normalized,
        unnormalized / n_samples,
        err_msg=f"Failed with {name}",
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(METRICS_WITH_NORMALIZE_OPTION.intersection(MULTILABELS_METRICS))
)
def test_normalize_option_multilabel_classification(name):
    """Check the ``normalize`` option on multilabel indicator targets.

    The normalized score must be strictly positive and equal to the
    unnormalized score divided by the number of samples that were scored.

    Parameters
    ----------
    name : str
        Key into ``ALL_METRICS`` of a metric that accepts ``normalize`` and
        multilabel input.
    """
    n_classes = 4
    n_generated = 100
    random_state = check_random_state(0)

    _, y_true = make_multilabel_classification(
        n_features=1,
        n_classes=n_classes,
        random_state=0,
        allow_unlabeled=True,
        n_samples=n_generated,
    )
    _, y_pred = make_multilabel_classification(
        n_features=1,
        n_classes=n_classes,
        random_state=1,
        allow_unlabeled=True,
        n_samples=n_generated,
    )

    # To make sure at least one empty label is present, stack an all-zero row.
    # An in-place ``y += [0] * n_classes`` only adds zeros element-wise and
    # leaves the data unchanged, so it cannot guarantee an unlabelled sample.
    y_true = np.vstack([y_true, [[0] * n_classes]])
    y_pred = np.vstack([y_pred, [[0] * n_classes]])

    # The extra row counts as a sample for the normalization check below.
    n_samples = y_true.shape[0]

    # Scores are drawn after stacking so that their shape matches the targets.
    y_score = random_state.uniform(size=y_true.shape)

    metrics = ALL_METRICS[name]
    pred = y_score if name in THRESHOLDED_METRICS else y_pred
    measure_normalized = metrics(y_true, pred, normalize=True)
    measure_not_normalized = metrics(y_true, pred, normalize=False)

    # A zero score would make the ratio check below uninformative.
    assert_array_less(
        -1.0 * measure_normalized,
        0,
        err_msg="We failed to test correctly the normalize option",
    )

    assert_allclose(
        measure_normalized,
        measure_not_normalized / n_samples,
        err_msg=f"Failed with {name}",
    )
|
|
|
|
|
|
@ignore_warnings
def _check_averaging(
    metric, y_true, y_pred, y_true_binarize, y_pred_binarize, is_multilabel
):
    """Compare each ``average`` mode of ``metric`` with a manual computation.

    ``average=None`` is compared with per-column calls on the binarized
    inputs, ``"micro"`` with a call on the flattened binarized inputs,
    ``"macro"`` with the mean of the per-label scores and ``"weighted"`` with
    their average weighted by the column sums of ``y_true_binarize`` (or 0
    when those sums are all zero). When ``is_multilabel`` is true,
    ``"samples"`` is compared with the mean of per-row calls. Unsupported
    ``average`` values must raise ``ValueError``.
    """
    n_samples, n_classes = y_true_binarize.shape

    def averaged(mode):
        return metric(y_true, y_pred, average=mode)

    # No averaging: one score per label.
    per_label = averaged(None)
    per_label_expected = [
        metric(y_true_binarize[:, col], y_pred_binarize[:, col])
        for col in range(n_classes)
    ]
    assert_allclose(per_label, per_label_expected)

    # Micro: every (sample, label) cell counts once.
    micro = averaged("micro")
    assert_allclose(
        micro, metric(y_true_binarize.ravel(), y_pred_binarize.ravel())
    )

    # Macro: plain mean of the per-label scores.
    macro = averaged("macro")
    assert_allclose(macro, np.mean(per_label))

    # Weighted: per-label scores weighted by the column sums of the targets.
    label_weights = np.sum(y_true_binarize, axis=0, dtype=int)
    weighted = averaged("weighted")
    if np.sum(label_weights) == 0:
        assert_allclose(weighted, 0)
    else:
        assert_allclose(weighted, np.average(per_label, weights=label_weights))

    # Samples: mean of the per-row scores, multilabel input only.
    if is_multilabel:
        samples = averaged("samples")
        per_sample = [
            metric(y_true_binarize[row], y_pred_binarize[row])
            for row in range(n_samples)
        ]
        assert_allclose(samples, np.mean(per_sample))

    for unsupported in ("unknown", "garbage"):
        with pytest.raises(ValueError):
            metric(y_true, y_pred, average=unsupported)
|
|
|
|
|
|
def check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score):
    """Dispatch to ``_check_averaging`` with the predictions ``name`` expects.

    Metrics in ``METRICS_WITH_AVERAGING`` receive ``y_pred`` and its binarized
    form; metrics in ``THRESHOLDED_METRICS_WITH_AVERAGING`` receive ``y_score``
    in both positions. Any other name raises ``ValueError``.
    """
    is_multilabel = type_of_target(y_true).startswith("multilabel")
    metric = ALL_METRICS[name]

    if name in METRICS_WITH_AVERAGING:
        predictions, predictions_binarized = y_pred, y_pred_binarize
    elif name in THRESHOLDED_METRICS_WITH_AVERAGING:
        predictions = predictions_binarized = y_score
    else:
        raise ValueError("Metric is not recorded as having an average option")

    _check_averaging(
        metric,
        y_true,
        predictions,
        y_true_binarize,
        predictions_binarized,
        is_multilabel,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_AVERAGING))
def test_averaging_multiclass(name):
    """Run the averaging checks on random three-class targets."""
    n_samples, n_classes = 50, 3
    rng = check_random_state(0)
    y_true = rng.randint(0, n_classes, size=(n_samples,))
    y_pred = rng.randint(0, n_classes, size=(n_samples,))
    y_score = rng.uniform(size=(n_samples, n_classes))

    # Binarizer fitted on the targets, then applied to targets and predictions.
    binarizer = LabelBinarizer().fit(y_true)
    y_true_binarize = binarizer.transform(y_true)
    y_pred_binarize = binarizer.transform(y_pred)

    check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(METRICS_WITH_AVERAGING | THRESHOLDED_METRICS_WITH_AVERAGING)
)
def test_averaging_multilabel(name):
    """Run the averaging checks on generated multilabel indicator targets.

    One generated label matrix is split in two halves: the first serves as
    targets and the second as predictions.
    """
    n_samples, n_classes = 40, 5
    half = n_samples // 2
    _, labels = make_multilabel_classification(
        n_features=1,
        n_classes=n_classes,
        random_state=5,
        n_samples=n_samples,
        allow_unlabeled=False,
    )
    y_true, y_pred = labels[:half], labels[half:]
    y_score = check_random_state(0).normal(size=(half, n_classes))

    # Indicator matrices are passed as their own binarized form.
    check_averaging(name, y_true, y_true, y_pred, y_pred, y_score)
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_AVERAGING))
def test_averaging_multilabel_all_zeroes(name):
    """Run the averaging checks when targets, predictions and scores are all 0."""
    shape = (20, 3)
    y_true = np.zeros(shape)
    y_pred = np.zeros(shape)
    y_score = np.zeros(shape)

    # Indicator matrices are passed as their own binarized form.
    check_averaging(name, y_true, y_true, y_pred, y_pred, y_score)
|
|
|
|
|
|
def test_averaging_binary_multilabel_all_zeroes():
    """Exercise ``_average_binary_score`` when the label weights sum to zero."""
    shape = (20, 3)
    y_true = np.zeros(shape)
    y_pred = np.zeros(shape)

    def binary_metric(y_true, y_score, average="macro"):
        # Route precision_score through the generic binary averaging helper.
        return _average_binary_score(precision_score, y_true, y_score, average)

    # Indicator matrices are passed as their own binarized form.
    _check_averaging(
        binary_metric,
        y_true,
        y_pred,
        y_true,
        y_pred,
        is_multilabel=True,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_AVERAGING))
def test_averaging_multilabel_all_ones(name):
    """Run the averaging checks when targets, predictions and scores are all 1."""
    shape = (20, 3)
    y_true = np.ones(shape)
    y_pred = np.ones(shape)
    y_score = np.ones(shape)

    # Indicator matrices are passed as their own binarized form.
    check_averaging(name, y_true, y_true, y_pred, y_pred, y_score)
|
|
|
|
|
|
@ignore_warnings
def check_sample_weight_invariance(name, metric, y1, y2):
    """Check how ``metric`` reacts to its ``sample_weight`` argument.

    With integer weights drawn from a fixed seed, the following must hold:

    - ``sample_weight=None`` equals unit weights;
    - the weighted score differs from the unweighted one;
    - array and list weights give the same score;
    - integer weights are equivalent to repeating samples;
    - zeroing weights is equivalent to dropping the matching samples;
    - unless ``name`` starts with ``"unnormalized"``, a common scaling of the
      weights leaves the score unchanged;
    - weights of the wrong length raise a ``ValueError``.
    """
    n_samples = len(y1)
    weights = np.random.RandomState(0).randint(1, 10, size=n_samples)

    # top_k_accuracy_score always lead to a perfect score for k > 1 in the
    # binary case
    if name == "top_k_accuracy_score":
        metric = partial(metric, k=1)

    # Unit weights must be equivalent to passing no weights at all.
    score_without_weights = metric(y1, y2, sample_weight=None)
    assert_allclose(
        score_without_weights,
        metric(y1, y2, sample_weight=np.ones(shape=n_samples)),
        err_msg="For %s sample_weight=None is not equivalent to sample_weight=ones"
        % name,
    )

    # Random weights are expected to change the score.
    score_with_weights = metric(y1, y2, sample_weight=weights)

    # The context manager expects the comparison to fail; if it passes, the
    # ValueError below escapes and reports the custom message.
    with pytest.raises(AssertionError):
        assert_allclose(score_without_weights, score_with_weights)
        raise ValueError(
            "Unweighted and weighted scores are unexpectedly "
            "almost equal (%s) and (%s) "
            "for %s" % (score_without_weights, score_with_weights, name)
        )

    # Weights supplied as a plain list must behave like the array.
    score_with_list = metric(y1, y2, sample_weight=weights.tolist())
    assert_allclose(
        score_with_weights,
        score_with_list,
        err_msg=(
            "Weighted scores for array and list "
            "sample_weight input are not equal (%s != %s) for %s"
        )
        % (score_with_weights, score_with_list, name),
    )

    # Integer weights must match physically repeating each sample.
    score_with_repeats = metric(
        np.repeat(y1, weights, axis=0),
        np.repeat(y2, weights, axis=0),
        sample_weight=None,
    )
    assert_allclose(
        score_with_weights,
        score_with_repeats,
        err_msg="Weighting %s is not equal to repeating samples" % name,
    )

    # Dropping the even-indexed samples must match giving them zero weight.
    kept = slice(1, None, 2)
    weights_kept = weights[kept]
    weights_zeroed = np.copy(weights)
    weights_zeroed[::2] = 0
    score_subset = metric(y1[kept], y2[kept], sample_weight=weights_kept)
    score_zeroed = metric(y1, y2, sample_weight=weights_zeroed)
    assert_allclose(
        score_subset,
        score_zeroed,
        err_msg=(
            "Zeroing weights does not give the same result as "
            "removing the corresponding samples (%s != %s) for %s"
        )
        % (score_zeroed, score_subset, name),
    )

    # Scaling all weights by a common factor must not change the score,
    # except for metrics whose name starts with "unnormalized".
    if not name.startswith("unnormalized"):
        for factor in [2, 0.3]:
            assert_allclose(
                score_with_weights,
                metric(y1, y2, sample_weight=weights * factor),
                err_msg="%s sample_weight is not invariant under scaling" % name,
            )

    # Weights that are twice as long as the data must be rejected with a
    # meaningful error.
    expected_error = (
        r"Found input variables with inconsistent numbers of "
        r"samples: \[{}, {}, {}\]".format(
            _num_samples(y1), _num_samples(y2), _num_samples(weights) * 2
        )
    )
    doubled_weights = np.hstack([weights, weights])
    with pytest.raises(ValueError, match=expected_error):
        metric(y1, y2, sample_weight=doubled_weights)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name",
    sorted(
        (set(ALL_METRICS) & set(REGRESSION_METRICS)) - METRICS_WITHOUT_SAMPLE_WEIGHT
    ),
)
def test_regression_sample_weight_invariance(name):
    """Run the sample-weight checks on random regression targets."""
    n_samples = 50
    rng = check_random_state(0)
    y_true = rng.random_sample(size=(n_samples,))
    y_pred = rng.random_sample(size=(n_samples,))
    check_sample_weight_invariance(name, ALL_METRICS[name], y_true, y_pred)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name",
    sorted(
        set(ALL_METRICS).difference(
            REGRESSION_METRICS,
            METRICS_WITHOUT_SAMPLE_WEIGHT,
            METRIC_UNDEFINED_BINARY,
        )
    ),
)
def test_binary_sample_weight_invariance(name):
    """Run the sample-weight checks on random binary targets."""
    n_samples = 50
    rng = check_random_state(0)
    y_true = rng.randint(0, 2, size=(n_samples,))
    y_pred = rng.randint(0, 2, size=(n_samples,))
    y_score = rng.random_sample(size=(n_samples,))

    # Thresholded metrics consume continuous scores, the others hard labels.
    predictions = y_score if name in THRESHOLDED_METRICS else y_pred
    check_sample_weight_invariance(name, ALL_METRICS[name], y_true, predictions)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name",
    sorted(
        set(ALL_METRICS).difference(
            REGRESSION_METRICS,
            METRICS_WITHOUT_SAMPLE_WEIGHT,
            METRIC_UNDEFINED_BINARY_MULTICLASS,
        )
    ),
)
def test_multiclass_sample_weight_invariance(name):
    """Run the sample-weight checks on random five-class targets."""
    n_samples = 50
    rng = check_random_state(0)
    y_true = rng.randint(0, 5, size=(n_samples,))
    y_pred = rng.randint(0, 5, size=(n_samples,))
    y_score = rng.random_sample(size=(n_samples, 5))
    metric = ALL_METRICS[name]

    if name not in THRESHOLDED_METRICS:
        check_sample_weight_invariance(name, metric, y_true, y_pred)
        return

    # softmax over the negated scores so that every row sums to one
    exponentials = np.exp(-y_score)
    row_totals = exponentials.sum(axis=-1).reshape(-1, 1)
    check_sample_weight_invariance(name, metric, y_true, exponentials / row_totals)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name",
    sorted(
        MULTILABELS_METRICS.union(
            THRESHOLDED_MULTILABEL_METRICS, MULTIOUTPUT_METRICS
        )
        - METRICS_WITHOUT_SAMPLE_WEIGHT
    ),
)
def test_multilabel_sample_weight_invariance(name):
    """Run the sample-weight checks on multilabel indicator targets.

    Targets stack two generated label matrices; predictions repeat the first
    matrix twice.
    """
    rng = check_random_state(0)
    generator_options = dict(
        n_features=1, n_classes=10, n_samples=50, allow_unlabeled=False
    )
    _, first = make_multilabel_classification(random_state=0, **generator_options)
    _, second = make_multilabel_classification(random_state=1, **generator_options)
    y_true = np.vstack([first, second])
    y_pred = np.vstack([first, first])
    y_score = rng.randint(1, 4, size=y_true.shape)

    # Thresholded metrics consume scores, the others indicator predictions.
    predictions = y_score if name in THRESHOLDED_METRICS else y_pred
    check_sample_weight_invariance(name, ALL_METRICS[name], y_true, predictions)
|
|
|
|
|
|
@ignore_warnings
def test_no_averaging_labels():
    """Check the ``labels`` argument when ``average=None``.

    In both the multiclass and the multilabel case, passing a reordered
    ``labels`` array must reorder the per-label scores accordingly.
    """
    labels = np.array([3, 0, 1, 2])
    # Position of each entry of ``labels`` within the sorted unique labels.
    inverse_labels = np.unique(labels, return_inverse=True)[1]

    multiclass_case = [np.array([0, 1, 2]), np.array([0, 2, 3])]
    multilabel_case = [
        np.array([[1, 1, 0, 0], [1, 1, 0, 0]]),
        np.array([[0, 0, 1, 1], [0, 1, 1, 0]]),
    ]

    cases = [multiclass_case, multilabel_case]
    for name, (y_true, y_pred) in product(METRICS_WITH_AVERAGING, cases):
        # 2-D predictions are only fed to metrics that support multilabel.
        if y_pred.ndim > 1 and name not in MULTILABELS_METRICS:
            continue

        metric = ALL_METRICS[name]
        reordered = metric(y_true, y_pred, labels=labels, average=None)
        default_order = metric(y_true, y_pred, average=None)
        assert_array_equal(reordered, default_order[inverse_labels])
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(MULTILABELS_METRICS - {"unnormalized_multilabel_confusion_matrix"})
)
def test_multilabel_label_permutations_invariance(name):
    """Multilabel metrics must not depend on the order of the label columns."""
    rng = check_random_state(0)
    n_samples, n_classes = 20, 4
    shape = (n_samples, n_classes)

    y_true = rng.randint(0, 2, size=shape)
    y_score = rng.randint(0, 2, size=shape)

    metric = ALL_METRICS[name]
    reference = metric(y_true, y_score)

    # Every ordering of the columns is tried.
    for perm in permutations(range(n_classes)):
        permuted = metric(y_true[:, perm], y_score[:, perm])
        assert_almost_equal(reference, permuted)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(THRESHOLDED_MULTILABEL_METRICS | MULTIOUTPUT_METRICS)
)
def test_thresholded_multilabel_multioutput_permutations_invariance(name):
    """Thresholded multilabel and multioutput metrics ignore column order."""
    rng = check_random_state(0)
    n_samples, n_classes = 20, 4
    y_true = rng.randint(0, 2, size=(n_samples, n_classes))
    y_score = rng.normal(size=y_true.shape)

    # Makes sure all samples have at least one label. This works around errors
    # when running metrics where average="sample"
    row_totals = y_true.sum(1)
    y_true[row_totals == n_classes, 0] = 0
    # Row sums are recomputed because the previous line may have changed them.
    y_true[y_true.sum(1) == 0, 0] = 1

    metric = ALL_METRICS[name]
    reference = metric(y_true, y_score)

    for perm in permutations(range(n_classes)):
        permuted = metric(y_true[:, perm], y_score[:, perm])

        if metric == mean_absolute_percentage_error:
            # Here we are not comparing the values in case of MAPE because
            # whenever y_true value is exactly zero, the MAPE value doesn't
            # signify anything. Thus, in this case we are just expecting
            # very large finite value.
            assert np.isfinite(permuted)
            assert permuted > 1e6
            continue

        assert_almost_equal(reference, permuted)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", sorted(set(THRESHOLDED_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS)
)
def test_thresholded_metric_permutation_invariance(name):
    """Relabelling the classes consistently must not change thresholded scores.

    For every permutation of the class ids, the targets are mapped through
    the permutation and the score columns are reordered with its inverse.
    """
    n_samples, n_classes = 100, 3
    rng = check_random_state(0)

    # Rows are normalized so that the scores of each sample sum to one.
    raw_scores = rng.rand(n_samples, n_classes)
    exponentials = np.exp(-raw_scores)
    y_score = exponentials / exponentials.sum(axis=-1).reshape(-1, 1)
    y_true = rng.randint(0, n_classes, size=n_samples)

    metric = ALL_METRICS[name]
    reference = metric(y_true, y_score)

    class_ids = np.arange(n_classes)
    for perm in permutations(range(n_classes)):
        # inverse_perm[perm[i]] == i
        inverse_perm = np.zeros(n_classes, dtype=int)
        inverse_perm[list(perm)] = class_ids

        relabelled_true = np.take(perm, y_true)
        reordered_score = y_score[:, inverse_perm]

        assert_almost_equal(reference, metric(relabelled_true, reordered_score))
|
|
|
|
|
|
@pytest.mark.parametrize("metric_name", CLASSIFICATION_METRICS)
def test_metrics_consistent_type_error(metric_name):
    """An understandable ``TypeError`` is raised on mismatched label types.

    ``y_true`` holds strings while ``y_pred`` holds integers.
    """
    metric = CLASSIFICATION_METRICS[metric_name]

    rng = np.random.RandomState(42)
    string_labels = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
    integer_labels = rng.randint(0, 2, size=string_labels.size)

    expected_message = "Labels in y_true and y_pred should be of the same type."
    with pytest.raises(TypeError, match=expected_message):
        metric(string_labels, integer_labels)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "metric, y_pred_threshold",
    [
        (average_precision_score, True),
        (brier_score_loss, True),
        (f1_score, False),
        (partial(fbeta_score, beta=1), False),
        (jaccard_score, False),
        (precision_recall_curve, True),
        (precision_score, False),
        (recall_score, False),
        (roc_curve, True),
    ],
)
@pytest.mark.parametrize("dtype_y_str", [str, object])
def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
    """Check the error raised when string targets come without ``pos_label``.

    The expected message depends on the metric's default ``pos_label``: one
    message when that default equals 1, another one otherwise.
    """
    rng = np.random.RandomState(42)
    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
    y2 = rng.randint(0, 2, size=y1.size)

    # When ``y_pred_threshold`` is false, the integer draws are turned into
    # string labels.
    if not y_pred_threshold:
        vocabulary = np.array(["spam", "eggs"], dtype=dtype_y_str)
        y2 = vocabulary[y2]

    default_pos_label = signature(metric).parameters["pos_label"].default
    if default_pos_label == 1:
        expected_message = (
            r"pos_label=1 is not a valid label. It should be one of "
            r"\['eggs', 'spam'\]"
        )
    else:
        expected_message = (
            "y_true takes value in {'eggs', 'spam'} and pos_label is not "
            "specified: either make y_true take value in {0, 1} or {-1, 1} or "
            "pass pos_label explicit"
        )

    with pytest.raises(ValueError, match=expected_message):
        metric(y1, y2)
|
|
|
|
|
|
def check_array_api_metric(
    metric, array_namespace, device, dtype_name, y_true_np, y_pred_np, sample_weight
):
    """Compare ``metric`` on NumPy inputs with the same call on ``xp`` arrays.

    The NumPy result is the reference. The inputs (and ``sample_weight`` when
    it is not ``None``) are converted with ``xp.asarray`` on ``device`` and the
    metric is evaluated again with ``array_api_dispatch=True``. Both results
    must agree within the tolerance returned by ``_atol_for_type(dtype_name)``.
    """
    xp = _array_api_for_tests(array_namespace, device)

    def to_xp(array):
        return xp.asarray(array, device=device)

    y_true_xp = to_xp(y_true_np)
    y_pred_xp = to_xp(y_pred_np)

    # Reference value, computed before the weights are converted.
    expected = metric(y_true_np, y_pred_np, sample_weight=sample_weight)

    if sample_weight is not None:
        sample_weight = to_xp(sample_weight)

    with config_context(array_api_dispatch=True):
        actual = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)

    assert_allclose(actual, expected, atol=_atol_for_type(dtype_name))
|
|
|
|
|
|
def check_array_api_binary_classification_metric(
    metric, array_namespace, device, dtype_name
):
    """Run ``check_array_api_metric`` on a small binary problem.

    The check is performed once without sample weights and once with weights
    of dtype ``dtype_name``.
    """
    run_check = partial(
        check_array_api_metric,
        metric,
        array_namespace,
        device,
        dtype_name,
        y_true_np=np.array([0, 0, 1, 1]),
        y_pred_np=np.array([0, 1, 0, 1]),
    )

    run_check(sample_weight=None)
    run_check(sample_weight=np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name))
|
|
|
|
|
|
def check_array_api_multiclass_classification_metric(
    metric, array_namespace, device, dtype_name
):
    """Run ``check_array_api_metric`` on a small four-class problem.

    The check is performed once without sample weights and once with weights
    of dtype ``dtype_name``.
    """
    run_check = partial(
        check_array_api_metric,
        metric,
        array_namespace,
        device,
        dtype_name,
        y_true_np=np.array([0, 1, 2, 3]),
        y_pred_np=np.array([0, 1, 0, 2]),
    )

    run_check(sample_weight=None)
    run_check(sample_weight=np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name))
|
|
|
|
|
|
# Maps each metric to the array API check functions that are run against it.
# Every metric gets its own list holding the binary and multiclass checkers.
array_api_metric_checkers = {
    metric: [
        check_array_api_binary_classification_metric,
        check_array_api_multiclass_classification_metric,
    ]
    for metric in (accuracy_score, zero_one_loss)
}
|
|
|
|
|
|
def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers):
    """Yield one ``(metric, checker)`` pair per registered check function.

    ``metric_checkers`` maps a metric to an iterable of checkers; pairs are
    produced in the mapping's iteration order.
    """
    for metric in metric_checkers:
        yield from ((metric, checker) for checker in metric_checkers[metric])
|
|
|
|
|
|
@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
)
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
    """Run each registered checker for every namespace/device/dtype combination."""
    check_func(metric, array_namespace, device, dtype_name)
|