diff --git a/orca_python/classifiers/NNOP.py b/orca_python/classifiers/NNOP.py index 9cb6e4d..34f56a7 100644 --- a/orca_python/classifiers/NNOP.py +++ b/orca_python/classifiers/NNOP.py @@ -1,10 +1,12 @@ """Neural Network with Ordered Partitions (NNOP).""" import math as math +from numbers import Integral, Real import numpy as np import scipy -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator, ClassifierMixin, _fit_context +from sklearn.utils._param_validation import Interval from sklearn.utils.multiclass import unique_labels from sklearn.utils.validation import check_array, check_is_fitted, check_X_y @@ -89,12 +91,20 @@ class NNOP(BaseEstimator, ClassifierMixin): """ + _parameter_constraints: dict = { + "epsilon_init": [Interval(Real, 0.0, None, closed="neither")], + "n_hidden": [Interval(Integral, 1, None, closed="left")], + "max_iter": [Interval(Integral, 1, None, closed="left")], + "lambda_value": [Interval(Real, 0.0, None, closed="neither")], + } + def __init__(self, epsilon_init=0.5, n_hidden=50, max_iter=500, lambda_value=0.01): self.epsilon_init = epsilon_init self.n_hidden = n_hidden self.max_iter = max_iter self.lambda_value = lambda_value + @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y): """Fit the model with the training data. diff --git a/orca_python/classifiers/NNPOM.py b/orca_python/classifiers/NNPOM.py index 0a96c67..74c69de 100644 --- a/orca_python/classifiers/NNPOM.py +++ b/orca_python/classifiers/NNPOM.py @@ -1,10 +1,12 @@ """Neural Network based on Proportional Odd Model (NNPOM).""" import math as math +from numbers import Integral, Real import numpy as np import scipy -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator, ClassifierMixin, _fit_context +from sklearn.utils._param_validation import Interval from sklearn.utils.multiclass import unique_labels from sklearn.utils.validation import check_array, check_is_fitted, check_X_y @@ -90,12 +92,20 @@ class NNPOM(BaseEstimator, ClassifierMixin): """ + _parameter_constraints: dict = { + "epsilon_init": [Interval(Real, 0.0, None, closed="neither")], + "n_hidden": [Interval(Integral, 1, None, closed="left")], + "max_iter": [Interval(Integral, 1, None, closed="left")], + "lambda_value": [Interval(Real, 0.0, None, closed="neither")], + } + def __init__(self, epsilon_init=0.5, n_hidden=50, max_iter=500, lambda_value=0.01): self.epsilon_init = epsilon_init self.n_hidden = n_hidden self.max_iter = max_iter self.lambda_value = lambda_value + @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y): """Fit the model with the training data. diff --git a/orca_python/classifiers/OrdinalDecomposition.py b/orca_python/classifiers/OrdinalDecomposition.py index 9202fe1..da3339d 100644 --- a/orca_python/classifiers/OrdinalDecomposition.py +++ b/orca_python/classifiers/OrdinalDecomposition.py @@ -1,14 +1,12 @@ """OrdinalDecomposition ensemble.""" import numpy as np -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator, ClassifierMixin, _fit_context +from sklearn.utils._param_validation import StrOptions from sklearn.utils.validation import check_array, check_is_fitted, check_X_y from orca_python.utilities import load_classifier -# from sys import path -# path.append('..') - class OrdinalDecomposition(BaseEstimator, ClassifierMixin): """OrdinalDecomposition ensemble classifier. @@ -88,6 +86,26 @@ class OrdinalDecomposition(BaseEstimator, ClassifierMixin): """ + _parameter_constraints: dict = { + "dtype": [ + StrOptions( + { + "ordered_partitions", + "one_vs_next", + "one_vs_followers", + "one_vs_previous", + } + ) + ], + "decision_method": [ + StrOptions( + {"exponential_loss", "hinge_loss", "logarithmic_loss", "frank_hall"} + ) + ], + "base_classifier": [str], + "parameters": [dict], + } + def __init__( self, dtype="ordered_partitions", @@ -100,6 +118,7 @@ def __init__( self.base_classifier = base_classifier self.parameters = parameters + @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y): """Fit the model with the training data. diff --git a/orca_python/classifiers/REDSVM.py b/orca_python/classifiers/REDSVM.py index 9fb721f..7bf9c53 100644 --- a/orca_python/classifiers/REDSVM.py +++ b/orca_python/classifiers/REDSVM.py @@ -1,14 +1,15 @@ """Reduction from ordinal regression to binary SVM (REDSVM).""" +from numbers import Integral, Real + import numpy as np -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator, ClassifierMixin, _fit_context +from sklearn.utils._param_validation import Interval, StrOptions from sklearn.utils.multiclass import unique_labels from sklearn.utils.validation import check_array, check_is_fitted, check_X_y from orca_python.classifiers.libsvmRank.python import svm -# from .libsvmRank.python import svm - class REDSVM(BaseEstimator, ClassifierMixin): """Reduction from ordinal regression to binary SVM classifiers. @@ -37,13 +38,16 @@ class REDSVM(BaseEstimator, ClassifierMixin): degree : int, default=3 Set degree in kernel function. - gamma : float, default=1/n_features - Set gamma in kernel function. + gamma : {'scale', 'auto'} or float, default=1.0 + Kernel coefficient determining the influence of individual training samples: + - 'scale': 1 / (n_features * X.var()) + - 'auto': 1 / n_features + - float: Must be non-negative. coef0 : float, default=0 Set coef0 in kernel function. - shrinking : int, default=1 + shrinking : bool, default=True Set whether to use the shrinking heuristics. tol : float, default=0.001 @@ -74,14 +78,42 @@ class REDSVM(BaseEstimator, ClassifierMixin): """ + _parameter_constraints: dict = { + "C": [Interval(Real, 0.0, None, closed="neither")], + "kernel": [ + StrOptions( + { + "linear", + "poly", + "rbf", + "sigmoid", + "stump", + "perceptron", + "laplacian", + "exponential", + "precomputed", + } + ) + ], + "degree": [Interval(Integral, 0, None, closed="left")], + "gamma": [ + StrOptions({"scale", "auto"}), + Interval(Real, 0.0, None, closed="neither"), + ], + "coef0": [Interval(Real, None, None, closed="neither")], + "shrinking": ["boolean"], + "tol": [Interval(Real, 0.0, None, closed="neither")], + "cache_size": [Interval(Real, 0.0, None, closed="neither")], + } + def __init__( self, C=1, kernel="rbf", degree=3, - gamma=None, + gamma="auto", coef0=0, - shrinking=1, + shrinking=True, tol=0.001, cache_size=100, ): @@ -94,6 +126,7 @@ def __init__( self.tol = tol self.cache_size = cache_size + @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y): """Fit the model with the training data. @@ -117,14 +150,24 @@ def fit(self, X, y): If parameters are invalid or data has wrong format. """ + # Additional strict validation for boolean parameters + if not isinstance(self.shrinking, bool): + raise ValueError( + f"The 'shrinking' parameter must be of type bool. " + f"Got {type(self.shrinking).__name__} instead." + ) + # Check that X and y have correct shape X, y = check_X_y(X, y) # Store the classes seen during fit self.classes_ = unique_labels(y) - # Set the default g value if necessary - if self.gamma is None: - self.gamma = 1 / np.size(X, 1) + # Set default gamma value if not specified + gamma_value = self.gamma + if self.gamma == "auto": + gamma_value = 1.0 / X.shape[1] + elif self.gamma == "scale": + gamma_value = 1.0 / (X.shape[1] * X.var()) # Map kernel type kernel_type_mapping = { @@ -138,18 +181,18 @@ def fit(self, X, y): "exponential": 7, "precomputed": 8, } - kernel_type = kernel_type_mapping.get(self.kernel, -1) + kernel_type = kernel_type_mapping[self.kernel] # Fit the model options = "-s 5 -t {} -d {} -g {} -r {} -c {} -m {} -e {} -h {} -q".format( str(kernel_type), str(self.degree), - str(self.gamma), + str(gamma_value), str(self.coef0), str(self.C), str(self.cache_size), str(self.tol), - str(self.shrinking), + str(1 if self.shrinking else 0), ) self.model_ = svm.fit(y.tolist(), X.tolist(), options) @@ -184,6 +227,6 @@ def predict(self, X): # Input validation X = check_array(X) - y_pred = svm.predict(X.tolist(), self.model_) + y_pred = np.array(svm.predict(X.tolist(), self.model_)) return y_pred diff --git a/orca_python/classifiers/SVOREX.py b/orca_python/classifiers/SVOREX.py index 437a6db..4a8688a 100644 --- a/orca_python/classifiers/SVOREX.py +++ b/orca_python/classifiers/SVOREX.py @@ -1,10 +1,13 @@ """Support Vector for Ordinal Regression (Explicit constraints) (SVOREX).""" -from sklearn.base import BaseEstimator, ClassifierMixin +from numbers import Integral, Real + +import numpy as np +from sklearn.base import BaseEstimator, ClassifierMixin, _fit_context +from sklearn.utils._param_validation import Interval, StrOptions from sklearn.utils.multiclass import unique_labels from sklearn.utils.validation import check_array, check_is_fitted, check_X_y -# from .svorex import svorex from orca_python.classifiers.svorex import svorex @@ -56,6 +59,22 @@ class SVOREX(BaseEstimator, ClassifierMixin): """ + _parameter_constraints: dict = { + "C": [Interval(Real, 0.0, None, closed="neither")], + "kernel": [ + StrOptions( + { + "gaussian", + "linear", + "poly", + } + ) + ], + "degree": [Interval(Integral, 0, None, closed="left")], + "tol": [Interval(Real, 0.0, None, closed="neither")], + "kappa": [Interval(Real, 0.0, None, closed="neither")], + } + def __init__(self, C=1.0, kernel="gaussian", degree=2, tol=0.001, kappa=1): self.C = C self.kernel = kernel @@ -63,6 +82,7 @@ def __init__(self, C=1.0, kernel="gaussian", degree=2, tol=0.001, kappa=1): self.tol = tol self.kappa = kappa + @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y): """Fit the model with the training data. @@ -135,6 +155,6 @@ def predict(self, X): # Input validation X = check_array(X) - y_pred = svorex.predict(X.tolist(), self.model_) + y_pred = np.array(svorex.predict(X.tolist(), self.model_)) return y_pred diff --git a/orca_python/classifiers/tests/test_nnop.py b/orca_python/classifiers/tests/test_nnop.py index 02b5240..46da287 100644 --- a/orca_python/classifiers/tests/test_nnop.py +++ b/orca_python/classifiers/tests/test_nnop.py @@ -21,16 +21,36 @@ def y(): @pytest.mark.parametrize( "param_name, invalid_value", [ + ("epsilon_init", 0), + ("epsilon_init", -1), ("n_hidden", -1), ("max_iter", -1), + ("lambda_value", -1e-5), ], ) -def test_nnop_fit_hyperparameters_validation(X, y, param_name, invalid_value): - """Test that hyperparameters are validated.""" +def test_nnop_hyperparameter_value_validation(X, y, param_name, invalid_value): + """Test that NNOP raises ValueError for invalid of hyperparameters.""" classifier = NNOP(**{param_name: invalid_value}) - model = classifier.fit(X, y) - assert model is None, "The NNOP fit method doesnt return Null on error" + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) + + +@pytest.mark.parametrize( + "param_name, invalid_value", + [ + ("epsilon_init", "high"), + ("n_hidden", 5.5), + ("max_iter", 2.5), + ("lambda_value", "tight"), + ], +) +def test_nnop_hyperparameter_type_validation(X, y, param_name, invalid_value): + """Test that NNOP raises ValueError for invalid types of hyperparameters.""" + classifier = NNOP(**{param_name: invalid_value}) + + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) def test_nnop_fit_input_validation(X, y): diff --git a/orca_python/classifiers/tests/test_nnpom.py b/orca_python/classifiers/tests/test_nnpom.py index f0a4e38..c992db4 100644 --- a/orca_python/classifiers/tests/test_nnpom.py +++ b/orca_python/classifiers/tests/test_nnpom.py @@ -21,16 +21,36 @@ def y(): @pytest.mark.parametrize( "param_name, invalid_value", [ + ("epsilon_init", 0), + ("epsilon_init", -1), ("n_hidden", -1), ("max_iter", -1), + ("lambda_value", -1e-5), ], ) -def test_nnpom_fit_hyperparameters_validation(X, y, param_name, invalid_value): - """Test that hyperparameters are validated.""" +def test_nnpom_hyperparameter_value_validation(X, y, param_name, invalid_value): + """Test that NNPOM raises ValueError for invalid of hyperparameters.""" classifier = NNPOM(**{param_name: invalid_value}) - model = classifier.fit(X, y) - assert model is None, "The NNPOM fit method doesnt return Null on error" + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) + + +@pytest.mark.parametrize( + "param_name, invalid_value", + [ + ("epsilon_init", "high"), + ("n_hidden", 5.5), + ("max_iter", 2.5), + ("lambda_value", "tight"), + ], +) +def test_nnpom_hyperparameter_type_validation(X, y, param_name, invalid_value): + """Test that NNPOM raises ValueError for invalid types of hyperparameters.""" + classifier = NNPOM(**{param_name: invalid_value}) + + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) def test_nnpom_fit_input_validation(X, y): diff --git a/orca_python/classifiers/tests/test_ordinal_decomposition.py b/orca_python/classifiers/tests/test_ordinal_decomposition.py index 8a335e8..74bbbe0 100644 --- a/orca_python/classifiers/tests/test_ordinal_decomposition.py +++ b/orca_python/classifiers/tests/test_ordinal_decomposition.py @@ -32,6 +32,46 @@ def test_ordinal_decomposition(X, y): npt.assert_array_equal(y_pred, y) +@pytest.mark.parametrize( + "param_name, invalid_value", + [ + ("dtype", "one_vs_all"), + ("dtype", "frank_hall"), + ("decision_method", "invalid"), + ("decision_method", "one_vs_next"), + ], +) +def test_ordinal_decomposition_hyperparameter_value_validation( + X, y, param_name, invalid_value +): + """Test that OrdinalDecomposition raises ValueError for invalid of + hyperparameters.""" + classifier = OrdinalDecomposition(**{param_name: invalid_value}) + + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) + + +@pytest.mark.parametrize( + "param_name, invalid_value", + [ + ("dtype", ["ordered_partitions"]), + ("decision_method", 0), + ("base_classifier", 3), + ("parameters", "tol"), + ("parameters", []), + ], +) +def test_ordinal_decomposition_hyperparameter_type_validation( + X, y, param_name, invalid_value +): + """Test that OrdinalDecomposition raises ValueError for invalid types of hyperparameters.""" + classifier = OrdinalDecomposition(**{param_name: invalid_value}) + + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) + + def test_ordinal_decomposition_fit_input_validation(X, y): """Test that input data is validated.""" X_invalid = X[:-1, :-1] diff --git a/orca_python/classifiers/tests/test_redsvm.py b/orca_python/classifiers/tests/test_redsvm.py index ec166ab..6d5f1b5 100644 --- a/orca_python/classifiers/tests/test_redsvm.py +++ b/orca_python/classifiers/tests/test_redsvm.py @@ -46,7 +46,7 @@ def test_redsvm_predict_matches_expected(kernel, expected_file): degree=2, gamma=0.1, coef0=0.5, - shrinking=0, + shrinking=False, tol=0.005, cache_size=150, ) @@ -63,28 +63,45 @@ def test_redsvm_predict_matches_expected(kernel, expected_file): @pytest.mark.parametrize( - "param_name, invalid_value, error_msg", + "param_name, invalid_value", [ - ("kernel", "unknown", "unknown kernel type"), - ("cache_size", -1, "cache_size <= 0"), - ("tol", -1, "eps <= 0"), - ("shrinking", 2, "shrinking != 0 and shrinking != 1"), - ( - "kernel", - "precomputed", - "Wrong input format: sample_serial_number out of range", - ), + ("C", 0), + ("C", -1), + ("degree", -1), + ("gamma", -0.5), + ("shrinking", 2), + ("tol", -1e-5), + ("cache_size", 0), + ("kernel", "unknown"), + ("gamma", "invalid_string"), ], ) -def test_redsvm_fit_hyperparameters_validation( - X, y, param_name, invalid_value, error_msg -): - """Test that hyperparameters are validated.""" +def test_redsvm_hyperparameter_value_validation(X, y, param_name, invalid_value): + """Test that REDSVM raises ValueError for invalid of hyperparameters.""" classifier = REDSVM(**{param_name: invalid_value}) - with pytest.raises(ValueError, match=error_msg): - model = classifier.fit(X, y) - assert model is None, "The REDSVM fit method doesnt return Null on error" + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) + + +@pytest.mark.parametrize( + "param_name, invalid_value", + [ + ("C", "high"), + ("kernel", 5), + ("degree", 2.5), + ("coef0", "bias"), + ("shrinking", "yes"), + ("tol", "tight"), + ("cache_size", "big"), + ], +) +def test_redsvm_hyperparameter_type_validation(X, y, param_name, invalid_value): + """Test that REDSVM raises ValueError for invalid types of hyperparameters.""" + classifier = REDSVM(**{param_name: invalid_value}) + + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) def test_redsvm_fit_input_validation(X, y): diff --git a/orca_python/classifiers/tests/test_svorex.py b/orca_python/classifiers/tests/test_svorex.py index 2178343..5aec800 100644 --- a/orca_python/classifiers/tests/test_svorex.py +++ b/orca_python/classifiers/tests/test_svorex.py @@ -48,22 +48,41 @@ def test_svorex_predict_matches_expected(kernel, expected_file): @pytest.mark.parametrize( - "params, error_msg", + "param_name, invalid_value", [ - ({"tol": 0}, "- T is invalid"), - ({"C": 0}, "- C is invalid"), - ({"kappa": 0}, "- K is invalid"), - ({"kernel": "poly", "degree": 0}, "- P is invalid"), - ({"kappa": -1}, "-1 is invalid"), + ("C", 0), + ("C", -1), + ("degree", -1), + ("tol", 0), + ("tol", -1e-5), + ("kernel", "unknown"), + ("kappa", -1), ], ) -def test_svorex_fit_hyperparameters_validation(X, y, params, error_msg): - """Test that hyperparameters are validated.""" - classifier = SVOREX(**params) +def test_svorex_hyperparameter_value_validation(X, y, param_name, invalid_value): + """Test that SVOREX raises ValueError for invalid of hyperparameters.""" + classifier = SVOREX(**{param_name: invalid_value}) - with pytest.raises(ValueError, match=error_msg): - model = classifier.fit(X, y) - assert model is None, "The SVOREX fit method doesnt return Null on error" + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) + + +@pytest.mark.parametrize( + "param_name, invalid_value", + [ + ("C", "high"), + ("kernel", 5), + ("degree", 2.5), + ("tol", "tight"), + ("kappa", "low"), + ], +) +def test_svorex_hyperparameter_type_validation(X, y, param_name, invalid_value): + """Test that SVOREX raises ValueError for invalid types of hyperparameters.""" + classifier = SVOREX(**{param_name: invalid_value}) + + with pytest.raises(ValueError, match=rf"The '{param_name}' parameter.*"): + classifier.fit(X, y) def test_svorex_fit_input_validation(X, y):