From c7e0e65eb9eefaa7f14c0e0c791773802b81f60c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:33:26 +0200 Subject: [PATCH 1/2] Fix quantized_embedding_norm undefined when normalize=False --- cebra/integrations/sklearn/helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cebra/integrations/sklearn/helpers.py b/cebra/integrations/sklearn/helpers.py index 2d2fc627..41e56003 100644 --- a/cebra/integrations/sklearn/helpers.py +++ b/cebra/integrations/sklearn/helpers.py @@ -155,6 +155,8 @@ def align_embeddings( quantized_sample / np.linalg.norm(quantized_sample, axis=0) for quantized_sample in quantized_embedding ] + quantized_embeddings.append(quantized_embedding_norm) + else: + quantized_embeddings.append(quantized_embedding) - quantized_embeddings.append(quantized_embedding_norm) return quantized_embeddings From d1c82b1b17257bb2c6a5bff720ed94ce8cd644dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:36:55 +0200 Subject: [PATCH 2/2] Add test --- tests/test_sklearn_metrics.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 4e765ba7..10c62453 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -27,6 +27,7 @@ import cebra import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra +import cebra.integrations.sklearn.helpers as cebra_sklearn_helpers import cebra.integrations.sklearn.metrics as cebra_sklearn_metrics @@ -385,6 +386,36 @@ def test_sklearn_runs_consistency(): invalid_embeddings_runs, between="runs") +def test_align_embeddings(): + # Example data + np.random.seed(42) + embedding1 = np.random.uniform(0, 1, (10000, 4)) + embedding2 = np.random.uniform(0, 1, (10000, 10)) + embedding3 = np.random.uniform(0, 1, (8000, 6)) + embeddings_datasets = [embedding1, embedding2, embedding3] + + labels1 = np.random.uniform(0, 1, (10000,)) + labels2 = np.random.uniform(0, 1, (10000,)) + labels3 = np.random.uniform(0, 1, (8000,)) + labels_datasets = [labels1, labels2, labels3] + + embeddings = cebra_sklearn_helpers.align_embeddings( + embeddings=embeddings_datasets, + labels=labels_datasets, + normalize=False, + n_bins=100) + + normalized_embeddings = cebra_sklearn_helpers.align_embeddings( + embeddings=embeddings_datasets, + labels=labels_datasets, + normalize=True, + n_bins=100) + + assert len(embeddings) == len(embeddings_datasets) + assert len(normalized_embeddings) == len(embeddings_datasets) + assert len(embeddings) == len(normalized_embeddings) + + @pytest.mark.parametrize("seed", [42, 24, 10]) def test_goodness_of_fit_score(seed): """