Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}
FROM python:${PYTHON_VERSION}-bookworm
ARG HLINK_EXTRAS=dev

RUN apt-get update && apt-get install default-jre-headless -y
RUN apt-get update && apt-get install openjdk-17-jre-headless -y

RUN mkdir /hlink
WORKDIR /hlink
Expand Down
21 changes: 18 additions & 3 deletions hlink/linking/core/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
_xgboost_available = True


def _ensure_seeded(params: dict[str, Any]) -> dict[str, Any]:
"""
Ensure that the given dictionary of parameters has a "seed" parameter.

This is useful for making results reproducible across different linking
runs. If the user doesn't set the "seed" parameter, then this function sets
it to 2133 (which is just as good as any other number, I suppose).
"""
return {"seed": 2133, **params}


def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
"""Given a model type and hyper-parameters for the model, return a
classifier of that type with those hyper-parameters, along with a
Expand Down Expand Up @@ -60,11 +71,11 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
post_transformer = SQLTransformer(statement="SELECT * FROM __THIS__")
features_vector = "features_vector"
if model_type == "random_forest":
params = _ensure_seeded(params)
classifier = RandomForestClassifier(
**params,
labelCol=dep_var,
featuresCol=features_vector,
seed=2133,
probabilityCol="probability_array",
)
post_transformer = SQLTransformer(
Expand Down Expand Up @@ -93,23 +104,23 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
)

elif model_type == "decision_tree":
params = _ensure_seeded(params)
classifier = DecisionTreeClassifier(
**params,
featuresCol=features_vector,
labelCol=dep_var,
probabilityCol="probability_array",
seed=2133,
)
post_transformer = SQLTransformer(
statement="SELECT *, parseProbVector(probability_array, 1) as probability FROM __THIS__"
)

elif model_type == "gradient_boosted_trees":
params = _ensure_seeded(params)
classifier = GBTClassifier(
**params,
featuresCol=features_vector,
labelCol=dep_var,
seed=2133,
)
post_transformer = (
hlink.linking.transformers.rename_prob_column.RenameProbColumn()
Expand All @@ -122,6 +133,8 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
"its dependencies. Try installing hlink with the lightgbm extra: "
"\n\n pip install hlink[lightgbm]"
)

params = _ensure_seeded(params)
classifier = synapse.ml.lightgbm.LightGBMClassifier(
**params,
featuresCol=features_vector,
Expand All @@ -138,6 +151,8 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
"the xgboost library and its dependencies. Try installing hlink with "
"the xgboost extra:\n\n pip install hlink[xgboost]"
)

params = _ensure_seeded(params)
classifier = xgboost.spark.SparkXGBClassifier(
**params,
features_col=features_vector,
Expand Down
16 changes: 16 additions & 0 deletions hlink/tests/core/classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# in this project's top-level directory, and also on-line at:
# https://github.com/ipums/hlink

import pytest

from hlink.linking.core.classifier import choose_classifier
from hlink.tests.markers import requires_lightgbm, requires_xgboost

Expand Down Expand Up @@ -30,3 +32,17 @@ def test_choose_classifier_supports_xgboost():
}
classifier, _post_transformer = choose_classifier("xgboost", params, "match")
assert classifier.getLabelCol() == "match"


@pytest.mark.parametrize(
    "classifier", ["random_forest", "decision_tree", "gradient_boosted_trees"]
)
def test_choose_classifier_can_set_seed_in_params(spark, classifier) -> None:
    """
    A "seed" entry in the hyper-parameters must be accepted by
    choose_classifier() and end up as the seed of the returned classifier.

    Passing "seed" used to raise an error because the seed was handled
    manually. See GitHub Issue #221.
    """
    expected_seed = 151015
    # Bind the result to its own name so that the parametrized model type in
    # `classifier` is not overwritten by the model object.
    model, _post_transformer = choose_classifier(
        classifier, {"seed": expected_seed}, "match"
    )
    assert model.getSeed() == expected_seed