diff --git a/Dockerfile b/Dockerfile index 0f2e036..6f73a45 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ ARG PYTHON_VERSION=3.10 -FROM python:${PYTHON_VERSION} +FROM python:${PYTHON_VERSION}-bookworm ARG HLINK_EXTRAS=dev -RUN apt-get update && apt-get install default-jre-headless -y +RUN apt-get update && apt-get install openjdk-17-jre-headless -y RUN mkdir /hlink WORKDIR /hlink diff --git a/hlink/linking/core/classifier.py b/hlink/linking/core/classifier.py index b58780a..9310c8e 100644 --- a/hlink/linking/core/classifier.py +++ b/hlink/linking/core/classifier.py @@ -30,6 +30,17 @@ _xgboost_available = True +def _ensure_seeded(params: dict[str, Any]) -> dict[str, Any]: + """ + Ensure that the given dictionary of parameters has a "seed" parameter. + + This is useful for making results reproducible across different linking + runs. If the user doesn't set the "seed" parameter, then this function sets + it to 2133 (which is just as good as any other number, I suppose). + """ + return {"seed": 2133, **params} + + def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str): """Given a model type and hyper-parameters for the model, return a classifier of that type with those hyper-parameters, along with a @@ -60,11 +71,11 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str): post_transformer = SQLTransformer(statement="SELECT * FROM __THIS__") features_vector = "features_vector" if model_type == "random_forest": + params = _ensure_seeded(params) classifier = RandomForestClassifier( **params, labelCol=dep_var, featuresCol=features_vector, - seed=2133, probabilityCol="probability_array", ) post_transformer = SQLTransformer( @@ -93,23 +104,23 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str): ) elif model_type == "decision_tree": + params = _ensure_seeded(params) classifier = DecisionTreeClassifier( **params, featuresCol=features_vector, labelCol=dep_var, probabilityCol="probability_array", - seed=2133, ) post_transformer = SQLTransformer( statement="SELECT *, parseProbVector(probability_array, 1) as probability FROM __THIS__" ) elif model_type == "gradient_boosted_trees": + params = _ensure_seeded(params) classifier = GBTClassifier( **params, featuresCol=features_vector, labelCol=dep_var, - seed=2133, ) post_transformer = ( hlink.linking.transformers.rename_prob_column.RenameProbColumn() @@ -122,6 +133,8 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str): "its dependencies. Try installing hlink with the lightgbm extra: " "\n\n pip install hlink[lightgbm]" ) + + params = _ensure_seeded(params) classifier = synapse.ml.lightgbm.LightGBMClassifier( **params, featuresCol=features_vector, @@ -138,6 +151,8 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str): "the xgboost library and its dependencies. Try installing hlink with " "the xgboost extra:\n\n pip install hlink[xgboost]" ) + + params = _ensure_seeded(params) classifier = xgboost.spark.SparkXGBClassifier( **params, features_col=features_vector, diff --git a/hlink/tests/core/classifier_test.py b/hlink/tests/core/classifier_test.py index e95b878..e47a7e9 100644 --- a/hlink/tests/core/classifier_test.py +++ b/hlink/tests/core/classifier_test.py @@ -3,6 +3,8 @@ # in this project's top-level directory, and also on-line at: # https://github.com/ipums/hlink +import pytest + from hlink.linking.core.classifier import choose_classifier from hlink.tests.markers import requires_lightgbm, requires_xgboost @@ -30,3 +32,17 @@ def test_choose_classifier_supports_xgboost(): } classifier, _post_transformer = choose_classifier("xgboost", params, "match") assert classifier.getLabelCol() == "match" + + +@pytest.mark.parametrize( + "classifier", ["random_forest", "decision_tree", "gradient_boosted_trees"] +) +def test_choose_classifier_can_set_seed_in_params(spark, classifier) -> None: + """ + Ensure that you can pass a "seed" parameter to the classifier. This used to + cause an error because of manual handling of the seed parameter. See GitHub + Issue #221. + """ + params = {"seed": 151015} + classifier, _post_transformer = choose_classifier(classifier, params, "match") + assert classifier.getSeed() == 151015