Guard cell-embedding normalization against zero-norm rows.

Both create_cell_embeddings_torch and create_cell_embeddings divide every row
of cell_embeddings by that row's norm. This patch replaces norms that are
exactly zero with one before the division (torch.where in the torch path,
boolean-mask assignment in the NumPy path). It also adds one test per path:
each feeds an expression matrix whose first row is all zeros and asserts that
output row 0 is all zeros and that the output contains no NaN.

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. Blank context lines were placed so
that every hunk matches the line counts in its @@ header; the indentation of
continuation lines is a best guess -- confirm against the repository before
applying.

diff --git a/src/inference.py b/src/inference.py
index 407dc13..c28b78b 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -129,8 +129,9 @@ def create_cell_embeddings_torch(expression_matrix, embedding_matrix, device="cp
     # Perform sparse matrix multiplication
     cell_embeddings = torch.sparse.mm(expression_matrix, embedding_matrix.T)
 
-    # Normalize the cell embeddings
+    # Normalize the cell embeddings, avoiding division by zero
     norms = torch.norm(cell_embeddings, dim=1, keepdim=True)
+    norms = torch.where(norms == 0, torch.ones_like(norms), norms)
     cell_embeddings = cell_embeddings / norms
 
     return cell_embeddings
@@ -154,8 +155,9 @@ def create_cell_embeddings(expression_matrix, embedding_matrix, valid_indices):
     # Perform the matrix multiplication (n_cells x n_embedding_dimensions)
     cell_embeddings = filtered_expression @ embedding_matrix.T
 
-    # Normalize the cell embeddings
+    # Normalize the cell embeddings, avoiding division by zero
     norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
+    norms[norms == 0] = 1
     cell_embeddings = cell_embeddings / norms
 
     return cell_embeddings
diff --git a/test/test_inference.py b/test/test_inference.py
index 327559c..b43c2b6 100644
--- a/test/test_inference.py
+++ b/test/test_inference.py
@@ -101,6 +101,18 @@ def test_create_cell_embeddings(sample_data):
     np.testing.assert_array_almost_equal(norms, np.ones(3))
 
 
+def test_create_cell_embeddings_with_zero_row(sample_data):
+    embedding_matrix, valid_indices = create_embedding_matrix(
+        sample_data["merged_embeddings"], sample_data["major_gene_ids"])
+
+    zero_expr = sparse.csr_matrix([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
+    cell_embeddings = create_cell_embeddings(zero_expr, embedding_matrix,
+                                             valid_indices)
+
+    assert np.all(cell_embeddings[0] == 0)
+    assert not np.any(np.isnan(cell_embeddings))
+
+
 if _torch_available:
 
     def test_create_cell_embeddings_torch(sample_data):
@@ -142,6 +154,25 @@ def test_create_cell_embeddings_torch(sample_data):
         assert torch.allclose(cell_embeddings,
                               torch.tensor(numpy_embeddings, dtype=torch.float32))
 
+    def test_create_cell_embeddings_torch_zero_row(sample_data):
+        embedding_matrix, valid_indices = create_embedding_matrix_torch(
+            sample_data["merged_embeddings"], sample_data["major_gene_ids"])
+
+        zero_expr = sparse.csr_matrix([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
+        filtered_expression = zero_expr[:, valid_indices]
+        expression_tensor = torch.sparse_csr_tensor(
+            torch.LongTensor(filtered_expression.indptr),
+            torch.LongTensor(filtered_expression.indices),
+            torch.FloatTensor(filtered_expression.data),
+            size=filtered_expression.shape,
+        )
+
+        cell_embeddings = create_cell_embeddings_torch(expression_tensor,
+                                                       embedding_matrix)
+
+        assert torch.all(cell_embeddings[0] == 0)
+        assert not torch.isnan(cell_embeddings).any()
+
     def test_device_handling():
         if torch.cuda.is_available():
             device = torch.device("cuda")