Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 23.3.0 # Replace by any tag/version: https://github.com/psf/black/tags
rev: 25.1.0 # Replace by any tag/version: https://github.com/psf/black/tags
hooks:
- id: black
language_version: python3
190 changes: 102 additions & 88 deletions muon/_core/plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Iterable, List, Optional, Sequence, Union
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Sequence, Union, Mapping
import warnings

from matplotlib.axes import Axes
Expand All @@ -19,12 +20,19 @@


def scatter(
data: Union[AnnData, MuData],
x: Optional[str] = None,
y: Optional[str] = None,
color: Optional[Union[str, Sequence[str]]] = None,
use_raw: Optional[bool] = None,
layers: Optional[Union[str, Sequence[str]]] = None,
data: AnnData | MuData,
x: str | None = None,
y: str | None = None,
color: str | Sequence[str] | None = None,
use_raw: bool | Mapping[str, bool] = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But how do you indicate this breaking change to users? I'm not familiar with the release process or your promises. I guess you;re not on a major version >=1 yet so it doesn't matter. Is there a changelog?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good question :) I don't think we ever defined a formal release process, but perhaps @gtca knows more. We do have a changelog in the docs directory, I guess I can add this there.

layers: (
str
| tuple[str, str, str]
| Mapping[str, str]
| tuple[Mapping[str, str], Mapping[str, str], Mapping[str, str]]
| None
) = None,
gene_symbols: str | Mapping[str, str | None] | None = None,
**kwargs,
):
"""
Expand All @@ -36,42 +44,56 @@ def scatter(

Parameters
----------
data : Union[AnnData, MuData]
data
MuData or AnnData object
x : Optional[str]
x
x coordinate
y : Optional[str]
y
y coordinate
color : Optional[Union[str, Sequence[str]]], optional (default: None)
color
Keys or a single key for variables or annotations of observations (.obs columns),
or a hex colour specification.
use_raw : Optional[bool], optional (default: None)
use_raw
Use `.raw` attribute of the modality where a feature (from `color`) is derived from.
If `None`, defaults to `True` if `.raw` is present and a valid `layer` is not provided.
layers : Optional[Union[str, Sequence[str]]], optional (default: None)
If a dictionary is given, it must have one entry for each modality.
layers
Names of the layers where x, y, and color come from.
No layer is used by default. A single layer value will be expanded to [layer, layer, layer].
If a dictionary is given, it must have one entry for each modality.
gene_symbols
Column of `.var` to search for `color` in.
If a dictionary is given, it must have one entry for each modality.
"""
if isinstance(data, AnnData):
localvars = locals()
for arg in ("use_raw", "layers", "gene_symbols"):
if isinstance(localvars[arg], Mapping):
raise ValueError(
f"`{arg}` can only be a dictionary if `data` is a `MuData` object."
)
return sc.pl.scatter(data, x=x, y=y, color=color, use_raw=use_raw, layers=layers, **kwargs)

if isinstance(layers, str) or layers is None:
layers = [layers, layers, layers]
if isinstance(layers, str) or isinstance(layers, Mapping) or layers is None:
layers = (layers, layers, layers)

obs = pd.DataFrame(
{
x: _get_values(data, x, use_raw=use_raw, layer=layers[0]),
y: _get_values(data, y, use_raw=use_raw, layer=layers[1]),
x: _get_values(data, x, use_raw=use_raw, layer=layers[0], gene_symbols=gene_symbols),
y: _get_values(data, y, use_raw=use_raw, layer=layers[1], gene_symbols=gene_symbols),
}
)
obs.index = data.obs_names
if color is not None:
# Workaround for scanpy#311, scanpy#1497
if isinstance(color, str):
color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
color_obs = _get_values(
data, color, use_raw=use_raw, layer=layers[2], gene_symbols=gene_symbols
)
color_obs = pd.DataFrame({color: color_obs})
else:
color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
color_obs = _get_values(
data, color, use_raw=use_raw, layer=layers[2], gene_symbols=gene_symbols
)

color_obs.index = data.obs_names
obs = pd.concat([obs, color_obs], axis=1, ignore_index=False)
Expand All @@ -96,11 +118,12 @@ def scatter(


def embedding(
data: Union[AnnData, MuData],
data: AnnData | MuData,
basis: str,
color: Optional[Union[str, Sequence[str]]] = None,
use_raw: Optional[bool] = None,
layer: Optional[str] = None,
color: str | Sequence[str] | None = None,
use_raw: bool | Mapping[str, bool] = False,
layer: str | Mapping[str, str | None] | None = None,
gene_symbols: str | Mapping[str, str | None] | None = None,
**kwargs,
):
"""
Expand All @@ -114,24 +137,38 @@ def embedding(

Parameters
----------
data : Union[AnnData, MuData]
data
MuData or AnnData object
basis : str
basis
Name of the `obsm` basis to use
color : Optional[Union[str, typing.Sequence[str]]], optional (default: None)
color
Keys for variables or annotations of observations (.obs columns).
Can be from any modality.
use_raw : Optional[bool], optional (default: None)
use_raw
Use `.raw` attribute of the modality where a feature (from `color`) is derived from.
If `None`, defaults to `True` if `.raw` is present and a valid `layer` is not provided.
layer : Optional[str], optional (default: None)
If a dictionary is given, it must have one entry for each modality.
layer
Name of the layer in the modality where a feature (from `color`) is derived from.
No layer is used by default. If a valid `layer` is provided, this takes precedence
over `use_raw=True`.
If a dictionary is given, it must have one entry for each modality.
gene_symbols
Column of `.var` to search for `color` in.
If a dictionary is given, it must have one entry for each modality.
"""
if isinstance(data, AnnData):
localvars = locals()
for arg in ("use_raw", "layer", "gene_symbols"):
if isinstance(localvars[arg], Mapping):
raise ValueError(
f"`{arg}` can only be a dictionary if `data` is a `MuData` object."
)
return sc.pl.embedding(
data, basis=basis, color=color, use_raw=use_raw, layer=layer, **kwargs
data,
basis=basis,
color=color,
use_raw=use_raw,
layer=layer,
gene_symbols=gene_symbols,
**kwargs,
)

# `data` is MuData
Expand All @@ -145,8 +182,8 @@ def embedding(
# basis is not a joint embedding
try:
mod, basis_mod = basis.split(":")
except ValueError:
raise ValueError(f"Basis {basis} is not present in the MuData object (.obsm)")
except ValueError as e:
raise ValueError(f"Basis {basis} is not present in the MuData object (.obsm)") from e

if mod not in data.mod:
raise ValueError(
Expand All @@ -168,6 +205,16 @@ def embedding(

obs = data.obs.loc[adata.obs.index.values]

if not isinstance(use_raw, Mapping):
use_rawd = use_raw
use_raw = defaultdict(lambda: use_rawd)
if not isinstance(layer, Mapping):
layerd = layer
layer = defaultdict(lambda: layerd)
if not isinstance(gene_symbols, Mapping):
gene_symbolsd = gene_symbols
gene_symbols = defaultdict(lambda: gene_symbolsd)

if color is None:
ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp)
return sc.pl.embedding(ad, basis=basis_mod, **kwargs)
Expand All @@ -180,77 +227,44 @@ def embedding(
else:
raise TypeError("Expected color to be a string or an iterable.")

varidx = {}
for m, mod in data.mod.items():
if layer[m] is not None and use_raw[m]:
raise ValueError("use_raw cannot be True when a layer is specified.")

var = mod.var if not use_raw[m] else mod.raw.var
varidx[m] = var.index if gene_symbols[m] is None else pd.Index(var[gene_symbols[m]])

# Fetch respective features
if not all([key in obs for key in keys]):
# {'rna': [True, False], 'prot': [False, True]}
keys_in_mod = {m: [key in data.mod[m].var_names for key in keys] for m in data.mod}

# .raw slots might have exclusive var_names
if use_raw is None or use_raw:
for i, k in enumerate(keys):
for m in data.mod:
if keys_in_mod[m][i] == False and data.mod[m].raw is not None:
keys_in_mod[m][i] = k in data.mod[m].raw.var_names
keys_in_mod = {m: [key in varidx[m] for key in keys] for m in data.mod}

# e.g. color="rna:CD8A" - especially relevant for mdata.axis == -1
mod_key_modifier: dict[str, str] = dict()
for i, k in enumerate(keys):
mod_key_modifier[k] = k
for m in data.mod:
for m, mod in data.mod.items():
if not keys_in_mod[m][i]:
k_clean = k
if k.startswith(f"{m}:"):
k_clean = k.split(":", 1)[1]

keys_in_mod[m][i] = k_clean in data.mod[m].var_names
keys_in_mod[m][i] = k_clean in varidx[m]
if keys_in_mod[m][i]:
mod_key_modifier[k] = k_clean
if use_raw is None or use_raw:
if keys_in_mod[m][i] == False and data.mod[m].raw is not None:
keys_in_mod[m][i] = k_clean in data.mod[m].raw.var_names

for m in data.mod:
for m, mod in data.mod.items():
if np.sum(keys_in_mod[m]) > 0:
mod_keys = np.array(keys)[keys_in_mod[m]]
mod_keys = np.array([mod_key_modifier[k] for k in mod_keys])

if use_raw is None or use_raw:
if data.mod[m].raw is not None:
keysidx = data.mod[m].raw.var.index.get_indexer_for(mod_keys)
fmod_adata = AnnData(
X=data.mod[m].raw.X[:, keysidx],
var=pd.DataFrame(index=mod_keys),
obs=data.mod[m].obs,
)
else:
if use_raw:
warnings.warn(
f"Attibute .raw is None for the modality {m}, using .X instead"
)
fmod_adata = data.mod[m][:, mod_keys]
else:
fmod_adata = data.mod[m][:, mod_keys]

if layer is not None:
if isinstance(layer, Dict):
m_layer = layer.get(m, None)
if m_layer is not None:
x = data.mod[m][:, mod_keys].layers[m_layer]
fmod_adata.X = x.todense() if issparse(x) else x
if use_raw:
warnings.warn(f"Layer='{layer}' superseded use_raw={use_raw}")
elif layer in data.mod[m].layers:
x = data.mod[m][:, mod_keys].layers[layer]
fmod_adata.X = x.todense() if issparse(x) else x
if use_raw:
warnings.warn(f"Layer='{layer}' superseded use_raw={use_raw}")
else:
warnings.warn(
f"Layer {layer} is not present for the modality {m}, using count matrix instead"
)
x = fmod_adata.X.toarray() if issparse(fmod_adata.X) else fmod_adata.X
mod_keys = [mod_key_modifier[k] for i, k in enumerate(keys) if keys_in_mod[m][i]]
obs = obs.join(
pd.DataFrame(x, columns=mod_keys, index=fmod_adata.obs_names),
sc.get.obs_df(
mod,
keys=mod_keys,
layer=layer[m],
use_raw=use_raw[m],
gene_symbols=gene_symbols[m],
),
how="left",
)

Expand Down
Loading