diff --git a/orca_python/metrics/__init__.py b/orca_python/metrics/__init__.py
index d83bf01..17b21bd 100644
--- a/orca_python/metrics/__init__.py
+++ b/orca_python/metrics/__init__.py
@@ -6,7 +6,6 @@
     ccr,
     gm,
     gmsec,
-    greater_is_better,
     mae,
     mmae,
     ms,
@@ -16,9 +15,14 @@
     tkendall,
     wkappa,
 )
+from .utils import (
+    compute_metric,
+    get_metric_names,
+    greater_is_better,
+    load_metric_as_scorer,
+)
 
 __all__ = [
-    "greater_is_better",
     "ccr",
     "amae",
     "gm",
@@ -32,4 +36,8 @@
     "spearman",
     "rps",
     "accuracy_off1",
+    "get_metric_names",
+    "greater_is_better",
+    "load_metric_as_scorer",
+    "compute_metric",
 ]
diff --git a/orca_python/metrics/metrics.py b/orca_python/metrics/metrics.py
index 304b392..cc1e6ed 100644
--- a/orca_python/metrics/metrics.py
+++ b/orca_python/metrics/metrics.py
@@ -7,48 +7,6 @@
 from sklearn.metrics import confusion_matrix, recall_score
 
 
-def greater_is_better(metric_name):
-    """Determine if greater values indicate better classification performance.
-
-    Needed when declaring a new scorer through make_scorer from sklearn.
-
-    Parameters
-    ----------
-    metric_name : str
-        Name of the metric.
-
-    Returns
-    -------
-    greater_is_better : bool
-        True if greater values indicate better classification performance, False otherwise.
-
-    Examples
-    --------
-    >>> from orca_python.metrics.metrics import greater_is_better
-    >>> greater_is_better("ccr")
-    True
-    >>> greater_is_better("mze")
-    False
-    >>> greater_is_better("mae")
-    False
-
-    """
-    greater_is_better_metrics = [
-        "ccr",
-        "ms",
-        "gm",
-        "gmsec",
-        "tkendall",
-        "wkappa",
-        "spearman",
-        "accuracy_off1",
-    ]
-    if metric_name in greater_is_better_metrics:
-        return True
-    else:
-        return False
-
-
 def ccr(y_true, y_pred):
     """Calculate the Correctly Classified Ratio.
 
diff --git a/orca_python/metrics/tests/test_metrics.py b/orca_python/metrics/tests/test_metrics.py
index 07fbb07..5b9e830 100644
--- a/orca_python/metrics/tests/test_metrics.py
+++ b/orca_python/metrics/tests/test_metrics.py
@@ -10,7 +10,6 @@
     ccr,
     gm,
     gmsec,
-    greater_is_better,
     mae,
     mmae,
     ms,
@@ -22,23 +21,6 @@
 )
 
 
-def test_greater_is_better():
-    """Test the greater_is_better function."""
-    assert greater_is_better("accuracy_off1")
-    assert greater_is_better("ccr")
-    assert greater_is_better("gm")
-    assert greater_is_better("gmsec")
-    assert not greater_is_better("mae")
-    assert not greater_is_better("mmae")
-    assert not greater_is_better("amae")
-    assert greater_is_better("ms")
-    assert not greater_is_better("mze")
-    assert not greater_is_better("rps")
-    assert greater_is_better("tkendall")
-    assert greater_is_better("wkappa")
-    assert greater_is_better("spearman")
-
-
 def test_accuracy_off1():
     """Test the Accuracy that allows errors in adjacent classes."""
     y_true = np.array([0, 1, 2, 3, 4, 5])
diff --git a/orca_python/metrics/tests/test_utils.py b/orca_python/metrics/tests/test_utils.py
new file mode 100644
index 0000000..2d93ee8
--- /dev/null
+++ b/orca_python/metrics/tests/test_utils.py
@@ -0,0 +1,171 @@
+"""Tests for the metrics module utilities."""
+
+import numpy.testing as npt
+import pytest
+
+from orca_python.metrics import (
+    accuracy_off1,
+    amae,
+    ccr,
+    gm,
+    gmsec,
+    mae,
+    mmae,
+    ms,
+    mze,
+    rps,
+    spearman,
+    tkendall,
+    wkappa,
+)
+from orca_python.metrics.utils import (
+    _METRICS,
+    compute_metric,
+    get_metric_names,
+    greater_is_better,
+    load_metric_as_scorer,
+)
+
+
+def test_get_metric_names():
+    """Test that get_metric_names returns all available metric names."""
+    all_metrics = get_metric_names()
+    expected_names = list(_METRICS.keys())
+
+    assert type(all_metrics) is list
+    assert all_metrics[:3] == ["accuracy_off1", "amae", "ccr"]
+    assert "rps" in all_metrics
+    npt.assert_array_equal(sorted(all_metrics), sorted(expected_names))
+
+
+@pytest.mark.parametrize(
+    "metric_name, gib",
+    [
+        ("accuracy_off1", True),
+        ("amae", False),
+        ("ccr", True),
+        ("gm", True),
+        ("gmsec", True),
+        ("mae", False),
+        ("mmae", False),
+        ("ms", True),
+        ("mze", False),
+        ("rps", False),
+        ("spearman", True),
+        ("tkendall", True),
+        ("wkappa", True),
+    ],
+)
+def test_greater_is_better(metric_name, gib):
+    """Test that greater_is_better returns the correct boolean for each metric."""
+    assert greater_is_better(metric_name) == gib
+
+
+def test_greater_is_better_invalid_name():
+    """Test that greater_is_better raises an error for an invalid metric name."""
+    error_msg = "Unrecognized metric name: 'roc_auc'."
+
+    with pytest.raises(KeyError, match=error_msg):
+        greater_is_better("roc_auc")
+
+
+@pytest.mark.parametrize(
+    "metric_name, metric",
+    [
+        ("rps", rps),
+        ("ccr", ccr),
+        ("accuracy_off1", accuracy_off1),
+        ("gm", gm),
+        ("gmsec", gmsec),
+        ("mae", mae),
+        ("mmae", mmae),
+        ("amae", amae),
+        ("ms", ms),
+        ("mze", mze),
+        ("tkendall", tkendall),
+        ("wkappa", wkappa),
+        ("spearman", spearman),
+    ],
+)
+def test_load_metric_as_scorer(metric_name, metric):
+    """Test that load_metric_as_scorer correctly loads the expected metric."""
+    metric_func = load_metric_as_scorer(metric_name)
+
+    assert metric_func._score_func == metric
+    assert metric_func._sign == (1 if greater_is_better(metric_name) else -1)
+
+
+@pytest.mark.parametrize(
+    "metric_name, metric",
+    [
+        ("ccr", ccr),
+        ("accuracy_off1", accuracy_off1),
+        ("gm", gm),
+        ("gmsec", gmsec),
+        ("mae", mae),
+        ("mmae", mmae),
+        ("amae", amae),
+        ("ms", ms),
+        ("mze", mze),
+        ("tkendall", tkendall),
+        ("wkappa", wkappa),
+        ("spearman", spearman),
+    ],
+)
+def test_correct_metric_output(metric_name, metric):
+    """Test that the loaded metric function produces the same output as the
+    original metric."""
+    y_true = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
+    y_pred = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]
+    metric_func = load_metric_as_scorer(metric_name)
+    metric_true = metric(y_true, y_pred)
+    metric_pred = metric_func._score_func(y_true, y_pred)
+
+    npt.assert_almost_equal(metric_pred, metric_true, decimal=6)
+
+
+def test_load_metric_invalid_name():
+    """Test that loading an invalid metric raises the correct exception."""
+    error_msg = "metric_name must be a string."
+    with pytest.raises(TypeError, match=error_msg):
+        load_metric_as_scorer(123)
+
+    error_msg = "Unrecognized metric name: 'roc_auc'."
+    with pytest.raises(KeyError, match=error_msg):
+        load_metric_as_scorer("roc_auc")
+
+
+@pytest.mark.parametrize(
+    "metric_name",
+    [
+        "ccr",
+        "accuracy_off1",
+        "gm",
+        "gmsec",
+        "mae",
+        "mmae",
+        "amae",
+        "ms",
+        "mze",
+        "tkendall",
+        "wkappa",
+        "spearman",
+    ],
+)
+def test_compute_metric(metric_name) -> None:
+    """Test that compute_metric returns the correct metric value."""
+    y_true = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
+    y_pred = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]
+    metric_value = compute_metric(metric_name, y_true, y_pred)
+    metric_func = load_metric_as_scorer(metric_name)
+    metric_true = metric_func._score_func(y_true, y_pred)
+
+    npt.assert_almost_equal(metric_value, metric_true, decimal=6)
+
+
+def test_compute_metric_invalid_name():
+    """Test that compute_metric raises an error for an invalid metric name."""
+    error_msg = "Unrecognized metric name: 'roc_auc'."
+
+    with pytest.raises(KeyError, match=error_msg):
+        compute_metric("roc_auc", [1, 2, 3], [1, 2, 3])
diff --git a/orca_python/metrics/utils.py b/orca_python/metrics/utils.py
new file mode 100644
index 0000000..ca26c1e
--- /dev/null
+++ b/orca_python/metrics/utils.py
@@ -0,0 +1,207 @@
+"""Utility functions for accessing and using classification metrics by name."""
+
+from sklearn.metrics import make_scorer
+
+from orca_python.metrics import (
+    accuracy_off1,
+    amae,
+    ccr,
+    gm,
+    gmsec,
+    mae,
+    mmae,
+    ms,
+    mze,
+    rps,
+    spearman,
+    tkendall,
+    wkappa,
+)
+
+# Mapping from metric names to their functions
+_METRICS = {
+    "accuracy_off1": accuracy_off1,
+    "amae": amae,
+    "ccr": ccr,
+    "gm": gm,
+    "gmsec": gmsec,
+    "mae": mae,
+    "mmae": mmae,
+    "ms": ms,
+    "mze": mze,
+    "rps": rps,
+    "spearman": spearman,
+    "tkendall": tkendall,
+    "wkappa": wkappa,
+}
+
+# Indicates whether a higher score means better performance
+_GREATER_IS_BETTER = {
+    "accuracy_off1": True,
+    "amae": False,
+    "ccr": True,
+    "gm": True,
+    "gmsec": True,
+    "mae": False,
+    "mmae": False,
+    "ms": True,
+    "mze": False,
+    "rps": False,
+    "spearman": True,
+    "tkendall": True,
+    "wkappa": True,
+}
+
+
+def get_metric_names():
+    """Get the names of all available metrics.
+
+    These names can be passed to :func:`~orca_python.metrics.compute_metric` to
+    compute the metric value.
+
+    Returns
+    -------
+    list of str
+        Names of all available metrics.
+
+    Examples
+    --------
+    >>> from orca_python.metrics import get_metric_names
+    >>> all_metrics = get_metric_names()
+    >>> type(all_metrics)
+    <class 'list'>
+    >>> all_metrics[:3]
+    ['accuracy_off1', 'amae', 'ccr']
+    >>> "rps" in all_metrics
+    True
+
+    """
+    return sorted(_METRICS.keys())
+
+
+def greater_is_better(metric_name):
+    """Determine if greater values indicate better classification performance.
+
+    Needed when declaring a new scorer through make_scorer from sklearn.
+
+    Parameters
+    ----------
+    metric_name : str
+        Name of the metric.
+
+    Returns
+    -------
+    greater_is_better : bool
+        True if greater values are better, False otherwise.
+
+    Raises
+    ------
+    KeyError
+        If the metric name is not recognized.
+
+    Examples
+    --------
+    >>> from orca_python.metrics import greater_is_better
+    >>> greater_is_better("ccr")
+    True
+    >>> greater_is_better("mze")
+    False
+    >>> greater_is_better("mae")
+    False
+
+    """
+    try:
+        return _GREATER_IS_BETTER[metric_name.lower().strip()]
+    except KeyError:
+        raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
+
+
+def load_metric_as_scorer(metric_name):
+    """Load a metric function by name and return a scorer compatible with
+    sklearn.
+
+    Parameters
+    ----------
+    metric_name : str
+        Name of the metric.
+
+    Returns
+    -------
+    callable
+        A scikit-learn compatible scorer.
+
+    Raises
+    ------
+    TypeError
+        If metric_name is not a string.
+
+    KeyError
+        If the metric name is not recognized.
+
+    Examples
+    --------
+    >>> from orca_python.metrics import load_metric_as_scorer
+    >>> scorer = load_metric_as_scorer("ccr")
+    >>> type(scorer)
+    <class 'sklearn.metrics._scorer._Scorer'>
+    >>> load_metric_as_scorer("mae")
+    make_scorer(mae, greater_is_better=False, response_method='predict')
+
+    """
+    if not isinstance(metric_name, str):
+        raise TypeError("metric_name must be a string.")
+
+    metric_name = metric_name.lower().strip()
+
+    try:
+        metric_func = _METRICS[metric_name]
+    except KeyError:
+        raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
+
+    gib = greater_is_better(metric_name)
+    scorer = make_scorer(metric_func, greater_is_better=gib)
+    scorer.metric_name = metric_name
+    return scorer
+
+
+def compute_metric(metric_name, y_true, y_pred):
+    """Compute the value of a metric from true and predicted labels.
+
+    Parameters
+    ----------
+    metric_name : str
+        Name of the metric.
+
+    y_true : np.ndarray, shape (n_samples,)
+        Ground truth labels.
+
+    y_pred : np.ndarray, shape (n_samples,)
+        Predicted labels.
+
+    Returns
+    -------
+    float
+        Numeric value of the classification metric.
+
+    Raises
+    ------
+    KeyError
+        If the metric name is not recognized.
+
+    Examples
+    --------
+    >>> from orca_python.metrics import compute_metric
+    >>> y_true = [0, 1, 2, 1, 0]
+    >>> y_pred = [0, 1, 1, 1, 0]
+    >>> compute_metric("ccr", y_true, y_pred)
+    0.8
+    >>> compute_metric("mae", y_true, y_pred)
+    0.2
+
+    """
+    try:
+        metric_func = _METRICS[metric_name]
+    except KeyError:
+        raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
+
+    return metric_func(y_true, y_pred)
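
For reviewers, a minimal usage sketch of the helpers added by this patch. The estimator (LogisticRegression) and the random data are illustrative stand-ins and are not part of the change; only get_metric_names, compute_metric, and load_metric_as_scorer come from orca_python.metrics.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score

    from orca_python.metrics import compute_metric, get_metric_names, load_metric_as_scorer

    # Toy data, for illustration only.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(90, 4))
    y = rng.integers(0, 3, size=90)

    print(get_metric_names())           # every registered metric name, sorted
    print(compute_metric("mae", y, y))  # direct computation by name -> 0.0

    # load_metric_as_scorer negates error metrics (greater_is_better=False),
    # so the returned scorer plugs straight into sklearn model selection.
    scorer = load_metric_as_scorer("mae")
    scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=3, scoring=scorer)
    print(scores)  # per-fold scores; higher (closer to 0) is better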