Merged
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -62,11 +62,11 @@ stf-download --species "Homo sapiens" --local_dir hest_data
We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer.

```bash
# Recommended: Run the Interaction model with 4 transformer layers
python scripts/run_preset.py --preset stf_interaction_l4
# Recommended: Run the Interaction model (Small)
python scripts/run_preset.py --preset stf_small

# Run the lightweight 2-layer version
python scripts/run_preset.py --preset stf_interaction_l2
# Run the lightweight Tiny version
python scripts/run_preset.py --preset stf_tiny

# Run baselines
python scripts/run_preset.py --preset he2rna_baseline
46 changes: 0 additions & 46 deletions docs/LATENT_DISCOVERY.md

This file was deleted.

22 changes: 11 additions & 11 deletions docs/MODELS.md
@@ -34,12 +34,8 @@ The SpatialTranscriptFormer models the **interaction between biological pathways
By default, the model operates in **Full Interaction** mode where all four information flows are active. Users can selectively disable any combination using the `--interactions` flag to explore architectural variants:

```bash
# Default: Full Interaction (all quadrants enabled)
--interactions p2p p2h h2p h2h

# Pathway Bottleneck: block H↔H to force all inter-patch
# communication through the pathway bottleneck
--interactions p2p p2h h2p
# Default: Small Interaction (CTransPath, 4 layers)
python scripts/run_preset.py --preset stf_small
```
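The quadrant gating can be pictured as a block attention mask over the concatenated [pathway; histology] token sequence. The following is a schematic NumPy sketch, not the repo's implementation, and it assumes `p2h` means pathway queries attending to histology keys (and likewise for the other quadrants):

```python
import numpy as np

def quadrant_mask(n_pathway, n_patch, interactions=("p2p", "p2h", "h2p", "h2h")):
    """Boolean attention mask over the concatenated [P; H] token sequence.

    mask[i, j] = True means token i may attend to token j. "p2h" lets
    pathway queries attend to histology keys, and so on.
    """
    n = n_pathway + n_patch
    mask = np.zeros((n, n), dtype=bool)
    p = slice(0, n_pathway)   # pathway token positions
    h = slice(n_pathway, n)   # histology patch positions
    if "p2p" in interactions:
        mask[p, p] = True
    if "p2h" in interactions:
        mask[p, h] = True
    if "h2p" in interactions:
        mask[h, p] = True
    if "h2h" in interactions:
        mask[h, h] = True
    return mask
```

Dropping `h2h`, for example, leaves patches able to exchange information only through the pathway tokens, which is the "pathway bottleneck" variant.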

> [!TIP]
@@ -53,7 +49,7 @@ Three additional design principles support these interactions:

- **Biological Initialisation** — The gene reconstruction weights are initialised from MSigDB Hallmark gene sets, providing a biologically-grounded starting point that the model refines during training.

### 2.2 Spatial Learning
## 2.2 Spatial Learning

The spatial relationships of gene expression are central to this model. It is not sufficient to predict correct expression magnitudes at each spot independently — the model must capture **where** on the tissue pathways are active and how that spatial pattern varies across the slide. Two mechanisms enforce this:

@@ -218,14 +214,19 @@ The model outputs these parameters, and the loss computes the negative log-likel

To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership.

To prevent highly expressed housekeeping genes from dominating the pathway's spatial pattern, the ground-truth targets are computed using **Z-score spatial normalization**:

1. Every gene's spatial expression pattern is standardized (mean=0, variance=1) across the tissue slide.
2. The normalized genes are projected onto the binary MSigDB pathway matrix.
3. The resulting pathway scores are **mean-aggregated** (divided by the number of known member genes in each pathway) rather than raw-summed.

This ensures every gene—including critical but lowly-expressed transcription factors—gets an equal vote in determining where a pathway is active.

The total objective becomes:
$$\mathcal{L} = \mathcal{L}_{gene} + \lambda_{aux} (1 - \text{PCC}(\text{pathway\_scores}, \text{target\_pathways}))$$
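The target computation and the auxiliary term can be sketched in NumPy as follows. This is a minimal illustration of the three steps above; `pathway_targets` and `aux_pathway_loss` are hypothetical names, not the repo's API:

```python
import numpy as np

def pathway_targets(Y, M):
    """Ground-truth pathway activations from gene expression targets.

    Y : (n_spots, n_genes) gene expression targets for one slide
    M : (n_pathways, n_genes) binary MSigDB membership matrix
    """
    # 1. Standardize every gene's spatial pattern across the slide
    Z = (Y - Y.mean(axis=0)) / (Y.std(axis=0) + 1e-8)
    # 2. Project onto the binary membership matrix, then
    # 3. mean-aggregate by the number of member genes per pathway
    counts = M.sum(axis=1, keepdims=True)
    return (Z @ M.T) / np.maximum(counts.T, 1)   # (n_spots, n_pathways)

def aux_pathway_loss(scores, targets):
    """1 - Pearson correlation, averaged over pathways."""
    s = scores - scores.mean(axis=0)
    t = targets - targets.mean(axis=0)
    pcc = (s * t).sum(axis=0) / (
        np.sqrt((s ** 2).sum(axis=0) * (t ** 2).sum(axis=0)) + 1e-8
    )
    return float(1.0 - pcc.mean())
```

Because each gene is standardized before projection, a lowly expressed transcription factor contributes as much to the target as an abundant housekeeping gene.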

The `--log-transform` flag applies `log1p` to targets, mitigating the heavy-tailed gene expression distribution where housekeeping genes dominate MSE.

The full training objective with pathway sparsity regularisation:
$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$

---

## 5. CLI Flags (Model Configuration)
@@ -239,7 +240,6 @@
| `--n-layers` | 2 | Transformer layers (minimum 2) |
| `--num-pathways` | 50 | Number of pathway bottleneck tokens |
| `--pathway-init` | off | Initialize gene_reconstructor from MSigDB |
| `--sparsity-lambda` | 0.0 | L1 regularisation on reconstruction weights |
| `--loss mse_pcc` | `mse` | Loss function (`mse`, `pcc`, `mse_pcc`, `zinb`) |
| `--pcc-weight` | 1.0 | Weight for PCC term in composite loss |
| `--pathway-loss-weight` | 0.0 | Weight for auxiliary pathway loss ($\lambda_{aux}$) |
22 changes: 5 additions & 17 deletions docs/PATHWAY_MAPPING.md
@@ -35,17 +35,11 @@ In this mode, the network receives direct supervision on its pathway tokens, gui
1. **Interaction**: Learnable pathway tokens $P$ interact with Histology patch features $H$ via self-attention (e.g., $p2h$, $h2p$).
2. **Activation**: Pathway scores $S \in \mathbb{R}^P$ are computed using a learnable temperature-scaled cosine similarity between the pathway tokens and image patch tokens.
3. **Gene Reconstruction**: $\hat{y} = S \cdot \mathbf{W}_{recon} + b$, where $\mathbf{W}_{recon}$ is initialized using the binary pathway membership matrix $M$.
- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth ($Y_{genes} \cdot M^T$) using a Pearson Correlation Coefficient (PCC) loss.
$$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, Y_{genes} \cdot M^T))$$
- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability.

#### 2. Data-Driven Discovery (Latent Projection)

In the absence of a biological prior, the model can learn its own "latent pathways".

- **Implementation**: $\mathbf{W}_{recon}$ is randomly initialized and the auxiliary pathway loss is disabled.
- **Sparsity Constraint**: We apply an L1 penalty to force the model to identify "canonical" sparse gene sets: $L_{total} = L_{gene} + \lambda_{sparsity} \|\mathbf{W}_{recon}\|_1$.
- **Benefit**: Can discover novel spatial-transcriptomic relationships that aren't yet captured in curated databases.
- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth using a Pearson Correlation Coefficient (PCC) loss.
- To prevent highly expressed housekeeping genes from dominating the signal, the raw spatial gene counts ($Y_{genes}$) are first **spatially Z-score normalized** ($Z_{genes}$).
- These are then projected onto the pathway matrix and mean-aggregated by member count ($C$):
$$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, \frac{Z_{genes} \cdot M^T}{C}))$$
- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability where every gene gets an equal vote.
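The activation and reconstruction steps can be sketched as below. This is a minimal NumPy illustration under stated assumptions: the real model learns the temperature and operates on attention-mixed tokens, and the function names here are illustrative:

```python
import numpy as np

def pathway_scores(P, H, temperature=0.07):
    """Temperature-scaled cosine similarity between pathway tokens and patches.

    P : (n_pathways, d) pathway tokens after interaction
    H : (n_patches, d)  histology patch tokens
    Returns S : (n_patches, n_pathways) pathway activation per patch.
    """
    Pn = P / (np.linalg.norm(P, axis=1, keepdims=True) + 1e-8)
    Hn = H / (np.linalg.norm(H, axis=1, keepdims=True) + 1e-8)
    return (Hn @ Pn.T) / temperature

def reconstruct_genes(S, W_recon, b):
    """Linear decoding y_hat = S . W_recon + b, where W_recon is
    (n_pathways, n_genes) and is initialised from the binary
    membership matrix M."""
    return S @ W_recon + b
```

Since the scores are cosine similarities divided by the temperature, each entry of `S * temperature` lies in [-1, 1].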

## 3. Generalizing to HEST1k Tissues

@@ -73,18 +67,12 @@ By supplying these functional groupings via `--custom-gmt`, the model's MTL proc
- GMT file cached in `.cache/` after first download.
- **Custom Pathway Definitions** (`--custom-gmt` flag): Users can override the default Hallmarks by providing a URL or local path to a `.gmt` file, enabling custom database integrations (e.g., KEGG, Reactome, or highly specific tissue masks).

- **Sparsity Regularization** (`--sparsity-lambda` flag): L1 penalty on `gene_reconstructor` weights to encourage pathway-like groupings when using data-driven (random) initialization.
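A `.gmt` file is tab-separated: set name, description, then the member genes. A minimal parser into the binary membership matrix $M$ might look like this (an illustrative sketch, not the repo's loader; `gene_index` is an assumed symbol-to-column mapping):

```python
import numpy as np

def parse_gmt(path, gene_index):
    """Read a .gmt file into (pathway_names, M).

    gene_index : dict mapping gene symbol -> column index
    M : binary (n_pathways, n_genes) membership matrix; genes absent
        from the panel are skipped.
    """
    names, rows = [], []
    with open(path) as fh:
        for line in fh:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 3:
                continue  # need name, description, and at least one gene
            names.append(fields[0])
            row = np.zeros(len(gene_index), dtype=np.float32)
            for gene in fields[2:]:
                if gene in gene_index:
                    row[gene_index[gene]] = 1.0
            rows.append(row)
    return names, np.stack(rows)
```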

### Usage

```bash
# With biological initialization (50 MSigDB Hallmarks)
python -m spatial_transcript_former.train \
--model interaction --pathway-init ...

# With data-driven pathways + sparsity
python -m spatial_transcript_former.train \
--model interaction --num-pathways 50 --sparsity-lambda 0.01 ...
```

- **Spatial Pathway Maps**: Visualize pathway activations as spatial heatmaps overlaid on histology using `stf-predict`. See the [README](../README.md) for inference instructions.
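The kind of overlay `stf-predict` produces can be approximated by scattering spot coordinates coloured by one pathway's score. A sketch of the idea, not the actual `stf-predict` implementation:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

def plot_pathway_map(coords, scores, pathway_name, out_png):
    """Scatter spot coordinates coloured by one pathway's activation."""
    fig, ax = plt.subplots(figsize=(5, 5))
    sc = ax.scatter(coords[:, 0], coords[:, 1], c=scores, s=8, cmap="magma")
    ax.set_title(pathway_name)
    ax.invert_yaxis()  # pixel coordinates: origin at top-left
    ax.set_aspect("equal")
    fig.colorbar(sc, ax=ax, label="activation")
    fig.savefig(out_png, dpi=150)
    plt.close(fig)
```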
46 changes: 12 additions & 34 deletions docs/TRAINING_GUIDE.md
@@ -126,44 +126,22 @@ python -m spatial_transcript_former.train \

> **Note**: `--pathway-init` overrides `--num-pathways` to 50 (the number of Hallmark gene sets). The GMT file is cached in `.cache/` after first download.

### Data-Driven Discovery (Latent Pathways)
### Recommended: Using Presets

To allow the model to discover its own spatial-transcriptomic relationships without biological priors, omit `--pathway-init` and apply sparsity regularization (`--sparsity-lambda`). This aims to force the model to identify "canonical" sparse gene sets.
For most cases, it is recommended to use the provided presets:

```bash
python -m spatial_transcript_former.train \
--data-dir A:\hest_data \
--model interaction \
--backbone ctranspath \
--use-nystrom \
--num-pathways 50 \
--sparsity-lambda 0.01 \
--precomputed \
--whole-slide \
--use-amp \
--log-transform \
--epochs 100
```

> **Note**: Without `--pathway-init`, the model disables the `AuxiliaryPathwayLoss` and relies entirely on the main reconstruction objectives and the L1 sparsity penalty. (I have yet to obtain results with this method.)
# Tiny (2 layers, 256 dim)
python scripts/run_preset.py --preset stf_tiny

### Robust Counting: ZINB + Auxiliary Loss
# Small (4 layers, 384 dim) - Recommended
python scripts/run_preset.py --preset stf_small

For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended.
# Medium (6 layers, 512 dim)
python scripts/run_preset.py --preset stf_medium

```bash
python -m spatial_transcript_former.train \
--data-dir A:\hest_data \
--model interaction \
--backbone ctranspath \
--pathway-init \
--loss zinb \
--pathway-loss-weight 0.5 \
--lr 5e-5 \
--batch-size 4 \
--whole-slide \
--precomputed \
--epochs 200
# Large (12 layers, 768 dim)
python scripts/run_preset.py --preset stf_large
```

### Choosing Interaction Modes
@@ -201,7 +179,7 @@ Submit with:
sbatch hpc/array_train.slurm
```

### Collecting Results
### Collecting Results (Currently broken!)

After experiments complete, aggregate all `results_summary.json` files into a comparison table:
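A workaround sketch of the aggregation, useful while the built-in collection is broken. It assumes each run directory contains a flat JSON dictionary of metrics; `collect_results` is an illustrative name, not the repo's script:

```python
import json
from pathlib import Path
import pandas as pd

def collect_results(runs_dir):
    """Gather every results_summary.json under runs_dir into one table,
    keyed by the parent run directory's name."""
    rows = []
    for path in sorted(Path(runs_dir).rglob("results_summary.json")):
        with open(path) as fh:
            summary = json.load(fh)
        summary["run"] = path.parent.name
        rows.append(summary)
    return pd.DataFrame(rows)
```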

@@ -243,8 +221,8 @@ python -m spatial_transcript_former.train --resume --output-dir runs/my_experime
| `--feature-dir` | Explicit path to precomputed features directory. | Overrides auto-detection. |
| `--loss` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`. | `mse_pcc` or `zinb` recommended. |
| `--pathway-loss-weight` | Weight ($\lambda$) for auxiliary pathway supervision. | Set `0.5` or `1.0` with `interaction` model. |
| `--sparsity-lambda` | L1 regularization weight for discovering latent pathways. | Use `0.01` when `--pathway-init` is NOT used. |
| `--interactions` | Enabled attention quadrants: `p2p`, `p2h`, `h2p`, `h2h`. | Default: `all` (Full Interaction). |
| `--plot-pathways-list` | Names of explicitly requested pathways to visualize as heatmaps during periodic validation. | Use with `--plot-pathways`. e.g. `HYPOXIA ANGIOGENESIS` |
| `--log-transform` | Apply log1p to gene expression targets. | Recommended for raw count data. |
| `--num-genes` | Number of HVGs to predict (default: 1000). | Match your `global_genes.json`. |
| `--mask-radius` | Euclidean distance for spatial attention gating. | Usually between 200 and 800. |
95 changes: 95 additions & 0 deletions scripts/analyze_expression_variance.py
@@ -0,0 +1,95 @@
import os
import argparse
import numpy as np
import h5py


def analyze_sample(h5ad_path):
    print(f"Analyzing {h5ad_path}...")

    with h5py.File(h5ad_path, "r") as f:
        # Check standard AnnData structure
        if "X" in f:
            if isinstance(f["X"], h5py.Group):
                # Sparse format (CSR/CSC)
                data_group = f["X"]["data"][:]
                n_cells = (
                    f["obs"]["_index"].shape[0]
                    if "_index" in f["obs"]
                    else len(f["obs"])
                )
                n_genes = (
                    f["var"]["_index"].shape[0]
                    if "_index" in f["var"]
                    else len(f["var"])
                )

                print(f"Data is sparse, shape: ({n_cells}, {n_genes})")
                print(f"Non-zero elements: {len(data_group)}")

                # Analyze non-zero elements
                mean_val = np.mean(data_group)
                max_val = np.max(data_group)
                min_val = np.min(data_group)

                print(f"Non-zero Mean: {mean_val:.4f}")
                print(f"Max Expression: {max_val:.4f}")
                print(f"Min Expression: {min_val:.4f}")

            else:
                # Dense array
                X = f["X"][:]
                print(f"Data is dense, shape: {X.shape}")

                # Basic stats
                mean_exp = np.mean(X, axis=0)  # per-gene mean
                var_exp = np.var(X, axis=0)  # per-gene variance
                max_exp = np.max(X, axis=0)

                sparsity = np.sum(X == 0) / X.size
                print(f"Overall Sparsity (zeros): {sparsity:.2%}")

                print(
                    f"Gene Mean Range: {np.min(mean_exp):.4f} to {np.max(mean_exp):.4f}"
                )
                print(f"Gene Var Range: {np.min(var_exp):.4f} to {np.max(var_exp):.4f}")
                print(f"Overall Max Expression: {np.max(max_exp):.4f}")

                # Check for extreme differences in variance
                var_ratio = np.max(var_exp) / (np.min(var_exp) + 1e-8)
                print(f"Ratio of max/min gene variance: {var_ratio:.4e}")

                return {
                    "sparsity": sparsity,
                    "var_ratio": var_ratio,
                    "max_exp": np.max(max_exp),
                }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir",
        type=str,
        default="A:\\hest_data",
        help="Path to HEST data directory",
    )
    args = parser.parse_args()

    st_dir = os.path.join(args.data_dir, "st")
    if not os.path.exists(st_dir):
        print(f"Error: Directory not found: {st_dir}")
        raise SystemExit(1)

    # Gather available samples
    samples = [f for f in os.listdir(st_dir) if f.endswith(".h5ad")]
    if not samples:
        print(f"No .h5ad files found in {st_dir}")
        raise SystemExit(1)

    # Analyze the first few samples
    for sample in samples[:3]:
        analyze_sample(os.path.join(st_dir, sample))
        print("-" * 50)