diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index c794bee..5c44506 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -23,6 +23,6 @@ jobs: - name: "Black: Code Formatting 👮🏽‍♀️" run: black --diff . && black -v --check . - name: install requirements - run: pip install -r environments/requirements.txt && pip install -r environments/requirements_genetic.txt + run: pip install -r environments/requirements.txt - name: run tests run: PYTHONPATH=. pytest diff --git a/pybalance/utils/matching_data.py b/pybalance/utils/matching_data.py index 17323fc..dfe7553 100644 --- a/pybalance/utils/matching_data.py +++ b/pybalance/utils/matching_data.py @@ -330,13 +330,13 @@ def describe_categoric(self, normalize=True) -> pd.DataFrame: for cat in self.headers["categoric"]: tmp = ( self.data.reset_index() - .groupby(["population", cat]) + .groupby([self.population_col, cat]) .count()[["index"]] .reset_index() ) tmp.loc[:, "feature"] = cat tmp = tmp.pivot( - index=["feature", cat], columns=["population"], values=["index"] + index=["feature", cat], columns=[self.population_col], values=["index"] ) tmp.columns = [c[1] for c in tmp.columns] tmp.index.names = ["feature", "value"] @@ -345,8 +345,9 @@ def describe_categoric(self, normalize=True) -> pd.DataFrame: out = pd.concat(out).fillna(0) if normalize: + out = out.astype(float) for c in counts.columns: - out.loc[:, c] = out[c] / counts.iloc[0][c] + out.loc[:, c] = (out[c].values / counts.iloc[0][c]).astype(float) else: out = out.astype(int) diff --git a/pybalance/utils/tests/test_matching_data.py b/pybalance/utils/tests/test_matching_data.py index af49ea1..1cb59a7 100644 --- a/pybalance/utils/tests/test_matching_data.py +++ b/pybalance/utils/tests/test_matching_data.py @@ -222,3 +222,12 @@ def test_specified_headers(): m = MatchingData(df, headers=headers) assert m.headers["categoric"] == [] assert m.headers["numeric"] == ["cat3"] + + +def test_describe_with_custom_population_col(): + # Create a dataset with a custom population column "Z" + data = generate_toy_dataset().data + data.rename(columns={"population": "Z"}, inplace=True) + quantiles = [0, 0.25, 0.5, 0.75, 1] + m = MatchingData(data, population_col="Z") + m.describe(aggregations=[], quantiles=quantiles)