diff --git a/README.md b/README.md index 4bfb038..bea1845 100644 --- a/README.md +++ b/README.md @@ -62,11 +62,11 @@ stf-download --species "Homo sapiens" --local_dir hest_data We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer. ```bash -# Recommended: Run the Interaction model with 4 transformer layers -python scripts/run_preset.py --preset stf_interaction_l4 +# Recommended: Run the Interaction model (Small) +python scripts/run_preset.py --preset stf_small -# Run the lightweight 2-layer version -python scripts/run_preset.py --preset stf_interaction_l2 +# Run the lightweight Tiny version +python scripts/run_preset.py --preset stf_tiny # Run baselines python scripts/run_preset.py --preset he2rna_baseline diff --git a/docs/LATENT_DISCOVERY.md b/docs/LATENT_DISCOVERY.md deleted file mode 100644 index 35ed192..0000000 --- a/docs/LATENT_DISCOVERY.md +++ /dev/null @@ -1,46 +0,0 @@ -# Latent Pathway Discovery - -The `SpatialTranscriptFormer` allows for the unsupervised discovery of biological processes directly from data. This "Data-Driven" approach doesn't rely on existing pathway databases like KEGG or MSigDB. - -## Latent Pathway Discovery - -The **SpatialTranscriptFormer** (specifically the `SpatialTranscriptFormer` class) can be used to discover unsupervised pathways from spatial transcriptomics data. The Pathway Bottleneck Layer forces the model to compress all image-to-gene mappings through a set of intermediate "factors." - -- Each **factor** acts as a learned pathway. -- The **Gene Reconstructor** matrix determines which genes belong to which latent pathway. -- By looking at the highest weights in this matrix, you can "decode" what biological process the model has discovered. - -## Sparsity Regularization (L1) - -To make discovered pathways cleaner and more interpretable, we use L1 regularization. This pushes low-contribution gene weights to zero, ensuring each latent factor is associated with a small, cohesive set of genes. 
- -### Usage in Training - -You can enable sparsity regularization using the `--sparsity-lambda` argument in `train.py`: - -```bash -python src/spatial_transcript_former/train.py \ - --model interaction \ - --num-pathways 50 \ - --sparsity-lambda 0.001 \ - --data-dir A:/hest_data -``` - -## Interpreting Discovered Pathways - -After training, you can inspect the `gene_reconstructor.weight` matrix `(G x P)` to name your pathways: - -```python -# Pseudo-code for interpretation -weights = model.gene_reconstructor.weight.data # (GeneCount, PathwayCount) -for p in range(num_pathways): - top_indices = torch.topk(weights[:, p], k=10).indices - top_genes = [gene_names[i] for i in top_indices] - print(f"Learned Pathway {p} Top Genes: {top_genes}") -``` - -### Example Clinical Insights -- **Factor A**: High weights for `VIM`, `SNAI1`, `ZEB1` -> Discovered **EMT** pathway. -- **Factor B**: High weights for `TFF3`, `CHGA`, `MUC2` -> Discovered **Secretory/Goblet** cell signatures. - -By visualizing these factor scores as spatial heatmaps, you can see where these discovered biological processes are active in the tissue. diff --git a/docs/MODELS.md b/docs/MODELS.md index 3029aa0..f6c14a1 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -34,12 +34,8 @@ The SpatialTranscriptFormer models the **interaction between biological pathways By default, the model operates in **Full Interaction** mode where all four information flows are active. 
Users can selectively disable any combination using the `--interactions` flag to explore architectural variants: ```bash -# Default: Full Interaction (all quadrants enabled) ---interactions p2p p2h h2p h2h - -# Pathway Bottleneck: block H↔H to force all inter-patch -# communication through the pathway bottleneck ---interactions p2p p2h h2p +# Default: Small Interaction (CTransPath, 4 layers) +python scripts/run_preset.py --preset stf_small ``` > [!TIP] @@ -53,7 +49,7 @@ Three additional design principles support these interactions: - **Biological Initialisation** — The gene reconstruction weights are initialised from MSigDB Hallmark gene sets, providing a biologically-grounded starting point that the model refines during training. -### 2.2 Spatial Learning +## 2.2 Spatial Learning The spatial relationships of gene expression are central to this model. It is not sufficient to predict correct expression magnitudes at each spot independently — the model must capture **where** on the tissue pathways are active and how that spatial pattern varies across the slide. Two mechanisms enforce this: @@ -218,14 +214,19 @@ The model outputs these parameters, and the loss computes the negative log-likel To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership. +To prevent highly-expressed housekeeping genes from dominating the pathway's spatial pattern, the ground-truth targets are computed using **Z-score spatial normalization**: + +1. Every gene's spatial expression pattern is standardized (mean=0, variance=1) across the tissue slide. +2. The normalized genes are projected onto the binary MSigDB pathway matrix. +3. The resulting pathway scores are **mean-aggregated** (divided by the number of known member genes in each pathway) rather than raw-summed. 
+ +This ensures every gene—including critical but lowly-expressed transcription factors—gets an equal vote in determining where a pathway is active. + The total objective becomes: $$\mathcal{L} = \mathcal{L}_{gene} + \lambda_{aux} (1 - \text{PCC}(\text{pathway\_scores}, \text{target\_pathways}))$$ The `--log-transform` flag applies `log1p` to targets, mitigating the heavy-tailed gene expression distribution where housekeeping genes dominate MSE. -The full training objective with pathway sparsity regularisation: -$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$ - --- ## 5. CLI Flags (Model Configuration) @@ -239,7 +240,6 @@ $$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$ | `--n-layers` | 2 | Transformer layers (minimum 2) | | `--num-pathways` | 50 | Number of pathway bottleneck tokens | | `--pathway-init` | off | Initialize gene_reconstructor from MSigDB | -| `--sparsity-lambda` | 0.0 | L1 regularisation on reconstruction weights | | `--loss mse_pcc` | `mse` | Loss function (`mse`, `pcc`, `mse_pcc`, `zinb`) | | `--pcc-weight` | 1.0 | Weight for PCC term in composite loss | | `--pathway-loss-weight` | 0.0 | Weight for auxiliary pathway loss ($\lambda_{aux}$) | diff --git a/docs/PATHWAY_MAPPING.md b/docs/PATHWAY_MAPPING.md index 6d3f373..7239a8a 100644 --- a/docs/PATHWAY_MAPPING.md +++ b/docs/PATHWAY_MAPPING.md @@ -35,17 +35,11 @@ In this mode, the network receives direct supervision on its pathway tokens, gui 1. **Interaction**: Learnable pathway tokens $P$ interact with Histology patch features $H$ via self-attention (e.g., $p2h$, $h2p$). 2. **Activation**: Pathway scores $S \in \mathbb{R}^P$ are computed using a learnable temperature-scaled cosine similarity between the pathway tokens and image patch tokens. 3. **Gene Reconstruction**: $\hat{y} = S \cdot \mathbf{W}_{recon} + b$, where $\mathbf{W}_{recon}$ is initialized using the binary pathway membership matrix $M$. 
-- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth ($Y_{genes} \cdot M^T$) using a Pearson Correlation Coefficient (PCC) loss. - $$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, Y_{genes} \cdot M^T))$$ -- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability. - -#### 2. Data-Driven Discovery (Latent Projection) - -In the absence of a biological prior, the model can learn its own "latent pathways". - -- **Implementation**: $\mathbf{W}_{recon}$ is randomly initialized and the auxiliary pathway loss is disabled. -- **Sparsity Constraint**: We apply an L1 penalty to force the model to identify "canonical" sparse gene sets: $L_{total} = L_{gene} + \lambda_{sparsity} \|\mathbf{W}_{recon}\|_1$. -- **Benefit**: Can discover novel spatial-transcriptomic relationships that aren't yet captured in curated databases. +- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth using a Pearson Correlation Coefficient (PCC) loss. + - To prevent highly expressed housekeeping genes dominating the signal, the raw spatial gene counts ($Y_{genes}$) are first **spatially Z-score normalized** ($Z_{genes}$). + - These are then projected onto the pathway matrix and mean-aggregated by member count ($C$): + $$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, \frac{Z_{genes} \cdot M^T}{C}))$$ +- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability where every gene gets an equal vote. ## 3. 
Generalizing to HEST1k Tissues @@ -73,18 +67,12 @@ By supplying these functional groupings via `--custom-gmt`, the model's MTL proc - GMT file cached in `.cache/` after first download. - **Custom Pathway Definitions** (`--custom-gmt` flag): Users can override the default Hallmarks by providing a URL or local path to a `.gmt` file, enabling custom database integrations (e.g., KEGG, Reactome, or highly specific tissue masks). -- **Sparsity Regularization** (`--sparsity-lambda` flag): L1 penalty on `gene_reconstructor` weights to encourage pathway-like groupings when using data-driven (random) initialization. - ### Usage ```bash # With biological initialization (50 MSigDB Hallmarks) python -m spatial_transcript_former.train \ --model interaction --pathway-init ... - -# With data-driven pathways + sparsity -python -m spatial_transcript_former.train \ - --model interaction --num-pathways 50 --sparsity-lambda 0.01 ... ``` - **Spatial Pathway Maps**: Visualize pathway activations as spatial heatmaps overlaid on histology using `stf-predict`. See the [README](../README.md) for inference instructions. diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 5a5e923..29f24d5 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -126,44 +126,22 @@ python -m spatial_transcript_former.train \ > **Note**: `--pathway-init` overrides `--num-pathways` to 50 (the number of Hallmark gene sets). The GMT file is cached in `.cache/` after first download. -### Data-Driven Discovery (Latent Pathways) +### Recommended: Using Presets -To allow the model to discover its own spatial-transcriptomic relationships without biological priors, omit `--pathway-init` and apply sparsity regularization (`--sparsity-lambda`). This aims to force the model to identify "canonical" sparse gene sets. 
+For most cases, it is recommended to use the provided presets: ```bash -python -m spatial_transcript_former.train \ - --data-dir A:\hest_data \ - --model interaction \ - --backbone ctranspath \ - --use-nystrom \ - --num-pathways 50 \ - --sparsity-lambda 0.01 \ - --precomputed \ - --whole-slide \ - --use-amp \ - --log-transform \ - --epochs 100 -``` - -> **Note**: Without `--pathway-init`, the model disables the `AuxiliaryPathwayLoss` and relies entirely on the main reconstruction objectives and the L1 sparsity penalty. (I am yet to obtain results with this method)... +# Tiny (2 layers, 256 dim) +python scripts/run_preset.py --preset stf_tiny -### Robust Counting: ZINB + Auxiliary Loss +# Small (4 layers, 384 dim) - Recommended +python scripts/run_preset.py --preset stf_small -For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended. +# Medium (6 layers, 512 dim) +python scripts/run_preset.py --preset stf_medium -```bash -python -m spatial_transcript_former.train \ - --data-dir A:\hest_data \ - --model interaction \ - --backbone ctranspath \ - --pathway-init \ - --loss zinb \ - --pathway-loss-weight 0.5 \ - --lr 5e-5 \ - --batch-size 4 \ - --whole-slide \ - --precomputed \ - --epochs 200 +# Large (12 layers, 768 dim) +python scripts/run_preset.py --preset stf_large ``` ### Choosing Interaction Modes @@ -201,7 +179,7 @@ Submit with: sbatch hpc/array_train.slurm ``` -### Collecting Results +### Collecting Results (Currently broken!) After experiments complete, aggregate all `results_summary.json` files into a comparison table: @@ -243,8 +221,8 @@ python -m spatial_transcript_former.train --resume --output-dir runs/my_experime | `--feature-dir` | Explicit path to precomputed features directory. | Overrides auto-detection. | | `--loss` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`. | `mse_pcc` or `zinb` recommended. | | `--pathway-loss-weight` | Weight ($\lambda$) for auxiliary pathway supervision. 
| Set `0.5` or `1.0` with `interaction` model. | -| `--sparsity-lambda` | L1 regularization weight for discovering latent pathways. | Use `0.01` when `--pathway-init` is NOT used. | | `--interactions` | Enabled attention quadrants: `p2p`, `p2h`, `h2p`, `h2h`. | Default: `all` (Full Interaction). | +| `--plot-pathways-list` | Names of explicitly requested pathways to visualize as heatmaps during periodic validation. | Use with `--plot-pathways`. e.g. `HYPOXIA ANGIOGENESIS` | | `--log-transform` | Apply log1p to gene expression targets. | Recommended for raw count data. | | `--num-genes` | Number of HVGs to predict (default: 1000). | Match your `global_genes.json`. | | `--mask-radius` | Euclidean distance for spatial attention gating. | Usually between 200 and 800. | diff --git a/scripts/analyze_expression_variance.py b/scripts/analyze_expression_variance.py new file mode 100644 index 0000000..e73a0eb --- /dev/null +++ b/scripts/analyze_expression_variance.py @@ -0,0 +1,95 @@ +import os +import argparse +import numpy as np +import h5py +import matplotlib.pyplot as plt +import pandas as pd +import json + + +def analyze_sample(h5ad_path): + print(f"Analyzing {h5ad_path}...") + + with h5py.File(h5ad_path, "r") as f: + # Check standard AnnData structure + if "X" in f: + if isinstance(f["X"], h5py.Group): + # Sparse format (CSR/CSC) + data_group = f["X"]["data"][:] + n_cells = ( + f["obs"]["_index"].shape[0] + if "_index" in f["obs"] + else len(f["obs"]) + ) + n_genes = ( + f["var"]["_index"].shape[0] + if "_index" in f["var"] + else len(f["var"]) + ) + + print(f"Data is sparse, shape: ({n_cells}, {n_genes})") + print(f"Non-zero elements: {len(data_group)}") + + # Analyze non-zero elements + mean_val = np.mean(data_group) + max_val = np.max(data_group) + min_val = np.min(data_group) + + print(f"Non-zero Mean: {mean_val:.4f}") + print(f"Max Expression: {max_val:.4f}") + print(f"Min Expression: {min_val:.4f}") + + else: + # Dense array + X = f["X"][:] + print(f"Data is 
dense, shape: {X.shape}") + + # Basic stats + mean_exp = np.mean(X, axis=0) # per gene mean + var_exp = np.var(X, axis=0) # per gene variance + max_exp = np.max(X, axis=0) + + sparsity = np.sum(X == 0) / X.size + print(f"Overall Sparsity (zeros): {sparsity:.2%}") + + print( + f"Gene Mean Range: {np.min(mean_exp):.4f} to {np.max(mean_exp):.4f}" + ) + print(f"Gene Var Range: {np.min(var_exp):.4f} to {np.max(var_exp):.4f}") + print(f"Overall Max Expression: {np.max(max_exp):.4f}") + + # Check for extreme differences in variance + var_ratio = np.max(var_exp) / (np.min(var_exp) + 1e-8) + print(f"Ratio of max/min gene variance: {var_ratio:.4e}") + + return { + "sparsity": sparsity, + "var_ratio": var_ratio, + "max_exp": np.max(max_exp), + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-dir", + type=str, + default="A:\\hest_data", + help="Path to HEST data directory", + ) + args = parser.parse_args() + + st_dir = os.path.join(args.data_dir, "st") + if not os.path.exists(st_dir): + print(f"Error: Directory not found: {st_dir}") + exit(1) + + # Get a few random samples + samples = [f for f in os.listdir(st_dir) if f.endswith(".h5ad")] + if not samples: + print(f"No .h5ad files found in {st_dir}") + + # Analyze the first couple of samples + for sample in samples[:3]: + analyze_sample(os.path.join(st_dir, sample)) + print("-" * 50) diff --git a/scripts/inspect_outputs.py b/scripts/inspect_outputs.py new file mode 100644 index 0000000..0279923 --- /dev/null +++ b/scripts/inspect_outputs.py @@ -0,0 +1,122 @@ +import torch +import json +import os +import argparse +import numpy as np +from spatial_transcript_former.models import SpatialTranscriptFormer +from spatial_transcript_former.data.utils import get_sample_ids, setup_dataloaders + + +class Args: + pass + + +args = Args() +args.data_dir = "A:\\hest_data" +args.epochs = 2000 +args.output_dir = "runs/stf_tiny" +args.model = "interaction" +args.backbone = "ctranspath" 
+args.precomputed = True +args.whole_slide = True +args.pathway_init = True +args.use_amp = True +args.log_transform = True +args.loss = "mse_pcc" +args.resume = True +args.n_layers = 2 +args.token_dim = 256 +args.n_heads = 4 +args.batch_size = 1 +args.vis_sample = "TENX29" +args.max_samples = 1 +args.organ = None +args.num_genes = 1000 +args.n_neighbors = 6 +args.use_global_context = False +args.global_context_size = 0 +args.augment = False +args.feature_dir = None +args.seed = 42 +args.warmup_epochs = 10 +args.sparsity_lambda = 0.0 + +device = "cuda" if torch.cuda.is_available() else "cpu" + +genes_path = "global_genes.json" +with open(genes_path, "r") as f: + gene_list = json.load(f)[:1000] +args.num_genes = len(gene_list) + +final_ids = get_sample_ids( + args.data_dir, precomputed=args.precomputed, backbone=args.backbone, max_samples=1 +) +train_loader, _ = setup_dataloaders(args, final_ids, []) + +model = SpatialTranscriptFormer( + num_genes=args.num_genes, + backbone_name=args.backbone, + pretrained=False, + token_dim=args.token_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, + num_pathways=50, + use_spatial_pe=True, + output_mode="counts", +) + +ckpt_path = os.path.join(args.output_dir, "latest_model_interaction.pth") +if os.path.exists(ckpt_path): + print("Loading", ckpt_path) + ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) + model.load_state_dict(ckpt["model_state_dict"], strict=False) +else: + print("No ckpt found!") + +model.to(device) +model.eval() + +with torch.no_grad(): + for batch in train_loader: + feats, genes, coords, mask = [x.to(device) for x in batch] + out = model(feats, rel_coords=coords, mask=mask, return_dense=True) + preds = out + + preds = torch.expm1(preds) if args.log_transform else preds + targets = torch.expm1(genes) if args.log_transform else genes + + patch_idx = None + for i in range(mask.shape[1]): + if not mask[0, i]: + patch_idx = i + break + + with open( + 
"C:/Users/wispy/.gemini/antigravity/brain/6a31ec6d-2f34-4f97-96b8-e437c2640219/model_output_sample.md", + "w", + ) as f: + f.write("# Model Output Sample (stf_tiny with simplifications)\n\n") + if patch_idx is not None: + f.write("### Target vs Prediction for a Single Valid Patch\n") + f.write("Showing the first 20 genes (absolute expression counts).\n\n") + + f.write("| Gene Index | Target Count (True) | Predicted Count |\n") + f.write("|------------|----------------------|-----------------|\n") + + t_vals = targets[0, patch_idx, :20].cpu().numpy() + p_vals = preds[0, patch_idx, :20].cpu().numpy() + + for i in range(20): + f.write(f"| {i} | {t_vals[i]:.2f} | {p_vals[i]:.2f} |\n") + + f.write("\n### Summary Statistics Across All Patches in Batch\n") + f.write(f"- Target Mean: {targets[~mask].mean().item():.4f}\n") + f.write(f"- Target Max: {targets[~mask].max().item():.4f}\n") + f.write(f"- Pred Mean: {preds[~mask].mean().item():.4f}\n") + f.write(f"- Pred Max: {preds[~mask].max().item():.4f}\n") + f.write(f"- Pred Min: {preds[~mask].min().item():.4f}\n") + else: + f.write("No valid patches found in sample.\n") + + print("Sample logic written to artifact.") + break diff --git a/scripts/migrate_logs_to_sqlite.py b/scripts/migrate_logs_to_sqlite.py new file mode 100644 index 0000000..2e5bbbd --- /dev/null +++ b/scripts/migrate_logs_to_sqlite.py @@ -0,0 +1,29 @@ +import os +import pandas as pd +import sqlite3 +import argparse + + +def migrate_csv_to_sqlite(run_dir): + csv_path = os.path.join(run_dir, "training_log.csv") + db_path = os.path.join(run_dir, "training_logs.sqlite") + + if not os.path.exists(csv_path): + print(f"No CSV found at {csv_path}") + return + + print(f"Migrating {csv_path} to {db_path}...") + df = pd.read_csv(csv_path) + + with sqlite3.connect(db_path) as conn: + df.to_sql("metrics", conn, if_exists="replace", index=False) + print("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--run-dir", 
type=str, required=True, help="Path to run directory" + ) + args = parser.parse_args() + migrate_csv_to_sqlite(args.run_dir) diff --git a/scripts/monitor.py b/scripts/monitor.py index a3270ce..a055edb 100644 --- a/scripts/monitor.py +++ b/scripts/monitor.py @@ -1,303 +1,54 @@ -import dash -from dash import dcc, html -from dash.dependencies import Input, Output -import plotly.graph_objs as go -import pandas as pd +#!/usr/bin/env python +""" +Real-time Training Monitor Entrypoint for SpatialTranscriptFormer. +""" import argparse -import os +import logging +from spatial_transcript_former.dashboard.app import init_app, app + + +def parse_args(): + parser = argparse.ArgumentParser(description="Real-time Training Monitor") + parser.add_argument( + "--run-dir", + type=str, + default=None, + help="Path to the experiment run directory containing training_logs.sqlite", + ) + parser.add_argument( + "--runs-dir", + type=str, + default=None, + help="Path to a directory containing MULTIPLE experiment run directories for comparison", + ) + parser.add_argument( + "--port", type=int, default=8050, help="Port to host the dashboard on" + ) + parser.add_argument( + "--interval", + type=int, + default=5000, + help="Log polling interval in milliseconds", + ) + args = parser.parse_args() + if not args.run_dir and not args.runs_dir: + parser.error("Must provide either --run-dir or --runs-dir") + return args -import flask -import glob - -# Set up argument parsing for the target run directory -parser = argparse.ArgumentParser(description="Real-time Training Monitor") -parser.add_argument( - "--run-dir", - type=str, - required=True, - help="Path to the experiment run directory containing training_log.csv", -) -parser.add_argument( - "--port", type=int, default=8050, help="Port to host the dashboard on" -) -parser.add_argument( - "--interval", type=int, default=5000, help="Log polling interval in milliseconds" -) -args = parser.parse_args() - -log_path = os.path.join(args.run_dir, 
"training_log.csv") - -# Configure Flask to serve images from the run directory -server = flask.Flask(__name__) - - -@server.route("/images/") -def serve_image(filename): - return flask.send_from_directory(os.path.abspath(args.run_dir), filename) - - -# Initialize Dash application -app = dash.Dash( - __name__, - server=server, - title=f"Training Monitor - {os.path.basename(args.run_dir)}", -) - -# Define application layout -app.layout = html.Div( - style={"fontFamily": "sans-serif", "padding": "20px"}, - children=[ - html.H1(f"Real-time Training Monitor: {os.path.basename(args.run_dir)}"), - html.Div( - id="last-updated", - style={"color": "gray", "fontStyle": "italic", "marginBottom": "10px"}, - ), - html.Button( - "Pause Updates", - id="pause-button", - n_clicks=0, - style={ - "marginBottom": "20px", - "padding": "10px", - "fontSize": "16px", - "cursor": "pointer", - "backgroundColor": "#f0f0f0", - "border": "1px solid #ccc", - "borderRadius": "5px", - }, - ), - html.Div( - [ - html.Label( - "Smoothing Window (Epochs):", - style={"fontWeight": "bold", "marginRight": "10px"}, - ), - dcc.Slider( - id="smoothing-slider", - min=1, - max=50, - step=1, - value=1, - marks={1: "1 (None)", 10: "10", 25: "25", 50: "50"}, - ), - ], - style={ - "marginBottom": "20px", - "padding": "10px", - "backgroundColor": "#f9f9f9", - "borderRadius": "5px", - }, - ), - html.Div( - [ - # Row 1: Losses + Correlation - html.Div( - [dcc.Graph(id="live-loss-graph", animate=False)], - style={ - "width": "48%", - "display": "inline-block", - "verticalAlign": "top", - }, - ), - html.Div( - [dcc.Graph(id="live-pcc-graph", animate=False)], - style={ - "width": "48%", - "display": "inline-block", - "verticalAlign": "top", - }, - ), - # Row 2: Variance + Learning Rate - html.Div( - [dcc.Graph(id="live-variance-graph", animate=False)], - style={ - "width": "48%", - "display": "inline-block", - "verticalAlign": "top", - }, - ), - html.Div( - [dcc.Graph(id="live-lr-graph", animate=False)], - style={ - 
"width": "48%", - "display": "inline-block", - "verticalAlign": "top", - }, - ), - ] - ), - html.Div( - [ - html.H3("Latest Inference Plot (Truth vs Pred)"), - html.Div(id="image-container"), - ], - style={ - "marginTop": "40px", - "textAlign": "center", - "backgroundColor": "#1a1a2e", - "padding": "20px", - "borderRadius": "10px", - }, - ), - # Hidden interval component for polling - dcc.Interval( - id="interval-component", - interval=args.interval, # in milliseconds - n_intervals=0, - disabled=False, - ), - ], -) - - -@app.callback( - Output("image-container", "children"), [Input("interval-component", "n_intervals")] -) -def update_image(n): - search_pattern = os.path.join(args.run_dir, "*.png") - list_of_files = glob.glob(search_pattern) - if not list_of_files: - return html.P( - "No inference plots found yet. Make sure to run training with --plot-pathways.", - style={"color": "red"}, - ) - - # Get the newest file - latest_file = max(list_of_files, key=os.path.getmtime) - filename = os.path.basename(latest_file) - - # Force reload by appending modifying timestamp query - mtime = os.path.getmtime(latest_file) - url = f"/images/{filename}?t={mtime}" - - return html.Img(src=url, style={"maxWidth": "100%", "height": "auto"}) - - -def _make_traces(df, cols, smoothing_window): - """Create Plotly traces for the given columns with optional smoothing.""" - traces = [] - for col in cols: - if col not in df.columns: - continue - y_data = df[col].dropna() - epochs = df.loc[y_data.index, "epoch"] - if smoothing_window and smoothing_window > 1: - y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() - traces.append(go.Scatter(x=epochs, y=y_data, mode="lines", name=col)) - return traces - - -@app.callback( - [ - Output("live-loss-graph", "figure"), - Output("live-pcc-graph", "figure"), - Output("live-variance-graph", "figure"), - Output("live-lr-graph", "figure"), - Output("last-updated", "children"), - ], - [Input("interval-component", "n_intervals"), 
Input("smoothing-slider", "value")], -) -def update_graphs(n, smoothing_window): - empty = dash.no_update - if not os.path.exists(log_path): - return empty, empty, empty, empty, "Waiting for training_log.csv..." - - try: - df = pd.read_csv(log_path) - except Exception as e: - return empty, empty, empty, empty, f"Error reading log: {e}" - - if df.empty or "epoch" not in df.columns: - return empty, empty, empty, empty, "Log empty or missing 'epoch'." - - margin = dict(l=40, r=40, t=40, b=40) - - # Chart 1: Losses (log scale) - loss_cols = [c for c in df.columns if "loss" in c.lower()] - loss_fig = { - "data": _make_traces(df, loss_cols, smoothing_window), - "layout": go.Layout( - title="Loss", - xaxis=dict(title="Epoch"), - yaxis=dict(title="Loss", type="log"), - margin=margin, - ), - } - - # Chart 2: Correlation (PCC, MAE) - corr_cols = [c for c in ["val_pcc", "val_mae"] if c in df.columns] - pcc_fig = { - "data": _make_traces(df, corr_cols, smoothing_window), - "layout": go.Layout( - title="Correlation & Error", - xaxis=dict(title="Epoch"), - yaxis=dict(title="Score"), - margin=margin, - ), - } - - # Chart 3: Prediction Variance - var_cols = [c for c in ["pred_variance"] if c in df.columns] - var_fig = { - "data": _make_traces(df, var_cols, smoothing_window), - "layout": go.Layout( - title="Prediction Variance (collapse detector)", - xaxis=dict(title="Epoch"), - yaxis=dict(title="Variance", type="log"), - margin=margin, - ), - } - - # Chart 4: Learning Rate - lr_cols = [c for c in ["lr"] if c in df.columns] - lr_fig = { - "data": _make_traces(df, lr_cols, smoothing_window), - "layout": go.Layout( - title="Learning Rate Schedule", - xaxis=dict(title="Epoch"), - yaxis=dict(title="LR", type="log"), - margin=margin, - ), - } - - last_epoch = df["epoch"].iloc[-1] - update_text = f"Last updated: Epoch {last_epoch} (Polled automatically)" - - return loss_fig, pcc_fig, var_fig, lr_fig, update_text +if __name__ == "__main__": + args = parse_args() -@app.callback( - [ - 
Output("interval-component", "disabled"), - Output("pause-button", "children"), - Output("pause-button", "style"), - ], - [Input("pause-button", "n_clicks")], -) -def toggle_pause(n_clicks): - base_style = { - "marginBottom": "20px", - "padding": "10px", - "fontSize": "16px", - "cursor": "pointer", - "borderRadius": "5px", - "border": "1px solid #ccc", - } - if n_clicks % 2 == 1: - # Paused state - active_style = { - **base_style, - "backgroundColor": "#ffcccc", - "borderColor": "#ff0000", - } - return True, "Resume Updates", active_style - # Active state - active_style = {**base_style, "backgroundColor": "#f0f0f0"} - return False, "Pause Updates", active_style + # Initialize the dash app + init_app(args) + if getattr(args, "runs_dir", None): + print(f"Tracking multiple runs in: {args.runs_dir}") + else: + print(f"Tracking single run at: {args.run_dir}") -if __name__ == "__main__": - print(f"Tracking log at: {log_path}") print(f"Starting dashboard on http://127.0.0.1:{args.port}/") + + # Run the server # Turn off debug to prevent double-reloading the data parser during polling app.run(debug=False, port=args.port) diff --git a/scripts/predict_sample.py b/scripts/predict_sample.py new file mode 100644 index 0000000..e6f76bb --- /dev/null +++ b/scripts/predict_sample.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +import argparse +import os +import torch +import json +from spatial_transcript_former.visualization import run_inference_plot + + +# Dummy class to hold loaded arguments +class RunArgs: + def __init__(self, **entries): + self.__dict__.update(entries) + + +def parse_args(): + parser = argparse.ArgumentParser("Predict sample pathways") + parser.add_argument( + "--sample-id", + required=True, + type=str, + help="Sample ID to run inference on (e.g. 
TENX156)", + ) + parser.add_argument( + "--run-dir", + required=True, + type=str, + help="Directory containing model weights and results_summary.json", + ) + parser.add_argument( + "--output-dir", type=str, default=".", help="Where to save the output plot" + ) + parser.add_argument( + "--epoch", type=int, default=0, help="Epoch number to label the plot with" + ) + return parser.parse_args() + + +def main(): + cli_args = parse_args() + + # Load args from run_dir + args_path = os.path.join(cli_args.run_dir, "results_summary.json") + if not os.path.exists(args_path): + raise FileNotFoundError(f"Missing {args_path}") + + with open(args_path, "r") as f: + summary_dict = json.load(f) + run_args_dict = summary_dict.get("config", {}) + + run_args = RunArgs(**run_args_dict) + run_args.output_dir = cli_args.output_dir + run_args.run_dir = cli_args.run_dir + + # Optional arguments that might be missing from older run configs + if not hasattr(run_args, "log_transform"): + run_args.log_transform = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Re-initialize the model based on run_args + if run_args.model == "baseline": + from spatial_transcript_former.models import SpatialTranscriptFormer + + model = SpatialTranscriptFormer( + backbone=run_args.backbone, + num_genes=run_args.num_genes, + dropout=run_args.dropout, + n_neighbors=run_args.n_neighbors, + ) + elif run_args.model == "interaction": + from spatial_transcript_former.models import SpatialTranscriptFormer + + model = SpatialTranscriptFormer( + num_genes=run_args.num_genes, + backbone_name=run_args.backbone, + pretrained=run_args.pretrained, + token_dim=getattr(run_args, "token_dim", 384), + n_heads=getattr(run_args, "n_heads", 6), + n_layers=getattr(run_args, "n_layers", 4), + num_pathways=getattr(run_args, "num_pathways", 0), + use_spatial_pe=getattr(run_args, "use_spatial_pe", True), + output_mode="zinb" if getattr(run_args, "loss", "") == "zinb" else "counts", + interactions=getattr(run_args, 
"interactions", None), + ) + else: + raise ValueError(f"Unknown model type: {run_args.model}") + + model.to(device) + + # Note: we explicitly load the *best* model if it exists, otherwise the latest + ckpt_path = os.path.join(cli_args.run_dir, f"best_model_{run_args.model}.pth") + if not os.path.exists(ckpt_path): + ckpt_path = os.path.join(cli_args.run_dir, f"latest_model_{run_args.model}.pth") + + if os.path.exists(ckpt_path): + print(f"Loading checkpoint from {ckpt_path}...") + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) + if "model_state_dict" in checkpoint: + model.load_state_dict(checkpoint["model_state_dict"]) + else: + model.load_state_dict(checkpoint) + else: + print( + f"Warning: No checkpoint found in {cli_args.run_dir}. Using untrained model." + ) + + print(f"Running inference for sample {cli_args.sample_id}...") + run_inference_plot(model, run_args, cli_args.sample_id, cli_args.epoch, device) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_preset.py b/scripts/run_preset.py index fe3e423..97efac2 100644 --- a/scripts/run_preset.py +++ b/scripts/run_preset.py @@ -5,124 +5,73 @@ from spatial_transcript_former.config import get_config -# Common flags for all STF interaction models -STF_COMMON = [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--pathway-init", - "--use-amp", - "--log-transform", - "--loss", - "mse_pcc", - "--resume", -] + +def make_stf_params(n_layers: int, token_dim: int, n_heads: int, batch_size: int): + """Helper to create standard SpatialTranscriptFormer parameters.""" + return { + "model": "interaction", + "backbone": "ctranspath", + "precomputed": True, + "whole-slide": True, + "pathway-init": True, + "use-amp": True, + "log-transform": True, + "loss": "mse_pcc", + "resume": True, + "n-layers": n_layers, + "token-dim": token_dim, + "n-heads": n_heads, + "batch-size": batch_size, + "vis_sample": "TENX29", + } + PRESETS = { # --- 
Baselines --- - "he2rna_baseline": [ - "--model", - "he2rna", - "--backbone", - "resnet50", - "--batch-size", - "64", - ], - "vit_baseline": [ - "--model", - "vit_st", - "--backbone", - "vit_b_16", - "--batch-size", - "32", - ], - "attention_mil": [ - "--model", - "attention_mil", - "--whole-slide", - "--precomputed", - "--batch-size", - "1", - ], - "transmil": [ - "--model", - "transmil", - "--whole-slide", - "--precomputed", - "--batch-size", - "1", - ], - # --- Interaction Models (Layer Scaling) --- - "stf_interaction_l2": STF_COMMON - + [ - "--n-layers", - "2", - "--token-dim", - "256", - "--n-heads", - "4", - "--batch-size", - "4", - ], - "stf_interaction_l4": STF_COMMON - + [ - "--n-layers", - "4", - "--token-dim", - "384", - "--n-heads", - "8", - "--batch-size", - "4", - ], - "stf_interaction_l6": STF_COMMON - + [ - "--n-layers", - "6", - "--token-dim", - "512", - "--n-heads", - "8", - "--batch-size", - "2", # Reduced batch size for large model memory - ], - # --- Specific Configurations --- - "stf_interaction_zinb": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--pathway-init", - "--sparsity-lambda", - "0", - "--lr", - "1e-4", - "--batch-size", - "4", - "--epochs", - "2500", - "--use-amp", - "--loss", - "zinb", - "--log-transform", - "--pathway-loss-weight", - "0.5", - "--interactions", - "p2p", - "p2h", - "h2p", - "h2h", - "--plot-pathways", - "--resume", - ], + "he2rna_baseline": { + "model": "he2rna", + "backbone": "resnet50", + "batch-size": 64, + }, + "vit_baseline": { + "model": "vit_st", + "backbone": "vit_b_16", + "batch-size": 32, + }, + "attention_mil": { + "model": "attention_mil", + "whole-slide": True, + "precomputed": True, + "batch-size": 1, + }, + "transmil": { + "model": "transmil", + "whole-slide": True, + "precomputed": True, + "batch-size": 1, + }, + # --- SpatialTranscriptFormer Variants --- + "stf_tiny": make_stf_params(n_layers=2, token_dim=256, n_heads=4, batch_size=8), + 
"stf_small": make_stf_params(n_layers=4, token_dim=384, n_heads=8, batch_size=8), + "stf_medium": make_stf_params(n_layers=6, token_dim=512, n_heads=8, batch_size=8), + "stf_large": make_stf_params(n_layers=12, token_dim=768, n_heads=12, batch_size=8), } +def params_to_args(params_dict): + """Convert a parameter dictionary to a list of CLI arguments.""" + args = [] + for key, value in params_dict.items(): + arg_name = f"--{key.replace('_', '-')}" + if value is True: + args.append(arg_name) + elif value is False or value is None: + continue + else: + args.extend([arg_name, str(value)]) + return args + + def main(): parser = argparse.ArgumentParser( description="Run Spatial TranscriptFormer training presets" @@ -169,7 +118,7 @@ def main(): cmd += ["--output-dir", f"./runs/{args.preset}"] # Add preset arguments - cmd += PRESETS[args.preset] + cmd += params_to_args(PRESETS[args.preset]) # Add any unknown arguments passed to this script cmd += unknown diff --git a/src/spatial_transcript_former/dashboard/__init__.py b/src/spatial_transcript_former/dashboard/__init__.py new file mode 100644 index 0000000..1cbef4b --- /dev/null +++ b/src/spatial_transcript_former/dashboard/__init__.py @@ -0,0 +1,7 @@ +""" +Dashboard package for SpatialTranscriptFormer model monitoring. 
+""" + +from .app import app, server + +__all__ = ["app", "server"] diff --git a/src/spatial_transcript_former/dashboard/app.py b/src/spatial_transcript_former/dashboard/app.py new file mode 100644 index 0000000..0fbac17 --- /dev/null +++ b/src/spatial_transcript_former/dashboard/app.py @@ -0,0 +1,43 @@ +import dash +import flask +import os +import argparse +import logging +from .layout import create_layout +from .callbacks import register_callbacks + +# Configure Python logging (app level) +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +# Global references (assigned during init) +server = flask.Flask(__name__) +app = dash.Dash(__name__, server=server, suppress_callback_exceptions=True) + + +def init_app(args): + """Initialize layout, routes, and callbacks with the given arguments.""" + + # Configure Flask to serve images from the run directory + @server.route("/images/<run_name>/<filename>") + @server.route("/images/<filename>") + def serve_image(filename, run_name=None): + if run_name and getattr(args, "runs_dir", None): + directory = os.path.join(os.path.abspath(args.runs_dir), run_name) + elif getattr(args, "run_dir", None): + directory = os.path.abspath(args.run_dir) + else: + return "Not found", 404 + + return flask.send_from_directory(directory, filename) + + app.title = ( + f"Training Monitor - {os.path.basename(args.run_dir)}" + if getattr(args, "run_dir", None) + else "Training Monitor: Compare Runs" + ) + app.layout = create_layout(args) + register_callbacks(app, args) + + return app diff --git a/src/spatial_transcript_former/dashboard/callbacks.py b/src/spatial_transcript_former/dashboard/callbacks.py new file mode 100644 index 0000000..2f5cb9b --- /dev/null +++ b/src/spatial_transcript_former/dashboard/callbacks.py @@ -0,0 +1,415 @@ +import os +import dash +import pandas as pd +from dash.dependencies import Input, Output, State +from dash import html, dcc +import plotly.graph_objs as go + +from .data_access import
get_training_data, get_available_images +from .layout import create_kpi_card + + +def register_callbacks(app, args): + """Register all Dashboard callbacks.""" + + from .data_access import get_available_runs + + @app.callback( + [Output("run-selector", "options"), Output("run-selector", "value")], + [Input("interval-component", "n_intervals")], + [State("run-selector", "value")], + ) + def update_run_selector(n, current_selected): + runs = get_available_runs(args) + options = [{"label": r["name"], "value": r["name"]} for r in runs] + + # If no runs selected, default to the first one + if not current_selected and runs: + return options, [runs[0]["name"]] + return options, dash.no_update + + @app.callback( + Output("download-data", "data"), + Input("export-button", "n_clicks"), + [State("run-selector", "value")], + prevent_initial_call=True, + ) + def export_data(n_clicks, selected_runs): + data_dict = get_training_data(args, selected_runs) + if not data_dict: + return dash.no_update + + # Combine into one exportable CSV + combined = [] + for run_name, df in data_dict.items(): + df_copy = df.copy() + df_copy.insert(0, "run_name", run_name) + combined.append(df_copy) + + final_df = pd.concat(combined, ignore_index=True) + return dcc.send_data_frame(final_df.to_csv, "training_metrics.csv", index=False) + + @app.callback( + [ + Output("interval-component", "disabled"), + Output("pause-button", "children"), + Output("pause-button", "style"), + ], + [Input("pause-button", "n_clicks")], + ) + def toggle_pause(n_clicks): + base_style = { + "padding": "10px 20px", + "fontSize": "14px", + "cursor": "pointer", + "border": "none", + "borderRadius": "8px", + "fontWeight": "bold", + "transition": "background-color 0.2s", + } + if n_clicks % 2 == 1: + active_style = { + **base_style, + "backgroundColor": "#ef4444", + "color": "white", + } + return True, "Resume Updates", active_style + + active_style = {**base_style, "backgroundColor": "#2563eb", "color": "white"} + return False, 
"Pause Updates", active_style + + @app.callback( + [Output("sample-dropdown", "options"), Output("sample-dropdown", "value")], + [Input("interval-component", "n_intervals"), Input("run-selector", "value")], + [State("sample-dropdown", "value")], + ) + def update_sample_dropdown(n, selected_runs, current_val): + images = get_available_images(args, selected_runs) + samples = sorted(list(set(img["sample"] for img in images))) + options = [{"label": s, "value": s} for s in samples] + + default_val = ( + current_val if current_val in samples else (samples[0] if samples else None) + ) + return options, default_val + + @app.callback( + [Output("epoch-dropdown", "options"), Output("epoch-dropdown", "value")], + [ + Input("sample-dropdown", "value"), + Input("interval-component", "n_intervals"), + Input("run-selector", "value"), + ], + [State("epoch-dropdown", "value")], + ) + def update_epoch_dropdown(selected_sample, n, selected_runs, current_epoch): + if not selected_sample: + return [], None + + images = get_available_images(args, selected_runs) + + # Since we might compare across runs, an epoch might be available for multiple runs + epochs = sorted( + list( + set(img["epoch"] for img in images if img["sample"] == selected_sample) + ), + reverse=True, + ) + options = [{"label": f"Epoch {e}", "value": e} for e in epochs] + + default_val = ( + current_epoch + if current_epoch in epochs + else (epochs[0] if epochs else None) + ) + return options, default_val + + @app.callback( + Output("image-container", "children"), + [ + Input("sample-dropdown", "value"), + Input("epoch-dropdown", "value"), + Input("run-selector", "value"), + ], + ) + def display_image(sample, epoch, selected_runs): + if not sample or not epoch: + return html.Div( + "Select a sample and epoch to view predictions.", + style={"color": "#64748b", "padding": "50px"}, + ) + + images = get_available_images(args, selected_runs) + + # We might have matches from multiple runs. We'll show them side by side. 
+ matches = [ + img for img in images if img["sample"] == sample and img["epoch"] == epoch + ] + + if not matches: + return html.Div("Image not found.", style={"color": "#ef4444"}) + + # If it's a multi-run directory, we need to map the image path correctly to a Flask route. + # But our Flask route expects only filenames and serves from args.run_dir. + # This will be tricky if we have an `--runs-dir`. We will need to update the server route. + # For now, we'll construct the HTML to expect a new API. + + children = [] + for match in matches: + url = f"/images/{match['run_name']}/{match['filename']}?t={match['mtime']}" + children.append( + html.Div( + [ + html.H4( + match["run_name"], + style={"color": "#38bdf8", "textAlign": "center"}, + ), + html.Img( + src=url, + style={ + "maxWidth": "100%", + "height": "auto", + "objectFit": "contain", + "borderRadius": "8px", + "marginBottom": "20px", + }, + ), + ], + style={"flex": "1", "minWidth": "300px", "padding": "10px"}, + ) + ) + + return html.Div( + children=children, + style={ + "display": "flex", + "flexWrap": "wrap", + "gap": "20px", + "width": "100%", + "justifyContent": "center", + }, + ) + + def _make_traces(data_dict, cols, smoothing_window): + """Create Plotly traces for the given columns across multiple runs.""" + traces = [] + + # Color palette for runs + colors = ["#38bdf8", "#fb7185", "#a3e635", "#c084fc", "#facc15", "#2dd4bf"] + + for r_idx, (run_name, df) in enumerate(data_dict.items()): + color = colors[r_idx % len(colors)] + + for c_idx, col in enumerate(cols): + if col not in df.columns: + continue + y_data = df[col].dropna() + epochs = df.loc[y_data.index, "epoch"] + if smoothing_window and smoothing_window > 1: + y_data = y_data.rolling( + window=smoothing_window, min_periods=1 + ).mean() + + # Use solid lines for primary metric, dashed for secondary if multiple cols + dash_style = "solid" if c_idx == 0 else "dash" + + label_name = ( + f"{run_name}" + if len(cols) == 1 + else f"{run_name} 
({col.replace('_', ' ').title()})" + ) + + traces.append( + go.Scatter( + x=epochs, + y=y_data, + mode="lines", + name=label_name, + line=dict(width=2.5, color=color, dash=dash_style), + showlegend=True, + ) + ) + return traces + + @app.callback( + [ + Output("live-loss-graph", "figure"), + Output("live-pcc-graph", "figure"), + Output("live-variance-graph", "figure"), + Output("live-lr-graph", "figure"), + Output("live-cpu-graph", "figure"), + Output("live-ram-graph", "figure"), + Output("live-gpu-graph", "figure"), + Output("last-updated", "children"), + Output("kpi-cards", "children"), + ], + [ + Input("interval-component", "n_intervals"), + Input("smoothing-slider", "value"), + Input("run-selector", "value"), + ], + ) + def update_metrics(n, smoothing_window, selected_runs): + empty = dash.no_update + data_dict = get_training_data(args, selected_runs) + + if not data_dict: + return ( + empty, + empty, + empty, + empty, + empty, + empty, + empty, + "Waiting for training data...", + [html.Div("No data yet or no runs selected", style={"color": "white"})], + ) + + # Common layout styles for dark mode charts + layout_defaults = dict( + plot_bgcolor="#1e293b", + paper_bgcolor="#1e293b", + font=dict(color="#cbd5e1"), + margin=dict(l=50, r=20, t=50, b=50), + xaxis=dict(gridcolor="#334155", zerolinecolor="#334155"), + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + ) + # Separate base yaxis style to mix in + yaxis_base = dict(gridcolor="#334155", zerolinecolor="#334155") + + # Losses + loss_cols = ["train_loss", "val_loss"] + loss_fig = go.Figure(data=_make_traces(data_dict, loss_cols, smoothing_window)) + loss_fig.update_layout( + title="Loss Landscape", + yaxis_type="log", + yaxis_title="Loss (Log Scale)", + yaxis=yaxis_base, + xaxis_title="Epoch", + **layout_defaults, + ) + + # Correlation / Errors + corr_cols = ["val_pcc", "val_mae"] + pcc_fig = go.Figure(data=_make_traces(data_dict, corr_cols, smoothing_window)) + 
pcc_fig.update_layout( + title="Validation Metrics", + yaxis_title="Score", + yaxis=yaxis_base, + xaxis_title="Epoch", + **layout_defaults, + ) + + # Variance + var_cols = ["pred_variance"] + var_fig = go.Figure(data=_make_traces(data_dict, var_cols, smoothing_window)) + var_fig.update_layout( + title="Prediction Variance", + yaxis_type="log", + yaxis_title="Variance", + yaxis=yaxis_base, + xaxis_title="Epoch", + **layout_defaults, + ) + + # Learning Rate + lr_cols = ["lr"] + lr_fig = go.Figure(data=_make_traces(data_dict, lr_cols, smoothing_window)) + lr_fig.update_layout( + title="Learning Rate", + yaxis_type="log", + yaxis_title="LR", + yaxis=yaxis_base, + xaxis_title="Epoch", + **layout_defaults, + ) + + # Hardware Metrics + cpu_cols = ["sys_cpu_percent"] + cpu_fig = go.Figure(data=_make_traces(data_dict, cpu_cols, smoothing_window)) + cpu_fig.update_layout( + title="CPU Usage", + yaxis_title="%", + yaxis={**yaxis_base, "range": [0, 100]}, + xaxis_title="Epoch", + **layout_defaults, + ) + + ram_cols = ["sys_ram_percent"] + ram_fig = go.Figure(data=_make_traces(data_dict, ram_cols, smoothing_window)) + ram_fig.update_layout( + title="RAM Usage", + yaxis_title="%", + yaxis={**yaxis_base, "range": [0, 100]}, + xaxis_title="Epoch", + **layout_defaults, + ) + + gpu_cols = ["sys_gpu_mem_mb"] + gpu_fig = go.Figure(data=_make_traces(data_dict, gpu_cols, smoothing_window)) + gpu_fig.update_layout( + title="GPU Memory", + yaxis_title="MB", + yaxis=yaxis_base, + xaxis_title="Epoch", + **layout_defaults, + ) + + # KPI Data - only show for the first selected run (or if only 1 run, that run) + # to avoid blowing up the UI with 20 cards. 
+ target_run_name = list(data_dict.keys())[0] if data_dict else None + kpi_elements = [] + update_text = "Data Loaded" + + if target_run_name: + df = data_dict[target_run_name] + last_row = df.iloc[-1] + last_epoch = int(last_row["epoch"]) + + run_lbl = f"{target_run_name} @ " if len(data_dict) > 1 else "" + + if "train_loss" in df.columns: + kpi_elements.append( + create_kpi_card( + "Train Loss", + f"{last_row['train_loss']:.4f}", + f"{run_lbl}Epoch {last_epoch}", + ) + ) + if "val_loss" in df.columns: + kpi_elements.append( + create_kpi_card( + "Val Loss", + f"{last_row['val_loss']:.4f}", + f"{run_lbl}Epoch {last_epoch}", + ) + ) + if "val_pcc" in df.columns: + kpi_elements.append( + create_kpi_card( + "Val PCC", + f"{last_row['val_pcc']:.4f}", + f"{run_lbl}Epoch {last_epoch}", + ) + ) + if "lr" in df.columns: + kpi_elements.append( + create_kpi_card("Learning Rate", f"{last_row['lr']:.2e}") + ) + + update_text = f"Last updated: {target_run_name} Epoch {last_epoch} (Live)" + + return ( + loss_fig, + pcc_fig, + var_fig, + lr_fig, + cpu_fig, + ram_fig, + gpu_fig, + update_text, + kpi_elements, + ) diff --git a/src/spatial_transcript_former/dashboard/data_access.py b/src/spatial_transcript_former/dashboard/data_access.py new file mode 100644 index 0000000..39917ed --- /dev/null +++ b/src/spatial_transcript_former/dashboard/data_access.py @@ -0,0 +1,156 @@ +import os +import sqlite3 +import pandas as pd +import glob +import logging +from threading import Lock + +# Simple thread-safe cache for images to avoid globbing disc constantly +_image_cache = {"last_check": 0, "images": []} +_cache_lock = Lock() + + +import os +import sqlite3 +import pandas as pd +import glob +import logging +from threading import Lock + +# Simple thread-safe cache for images to avoid globbing disc constantly +_image_cache = {"last_check": 0, "images": []} +_cache_lock = Lock() + + +def get_available_runs(args): + """Returns a list of dicts with name and path for available runs.""" + runs = [] 
+ if getattr(args, "run_dir", None): + if os.path.exists(args.run_dir): + runs.append( + { + "name": os.path.basename(os.path.normpath(args.run_dir)), + "path": args.run_dir, + } + ) + + if getattr(args, "runs_dir", None) and os.path.exists(args.runs_dir): + # Scan immediate subdirectories + for entry in os.scandir(args.runs_dir): + if entry.is_dir() and not entry.name.startswith("."): # Ignore hidden + runs.append({"name": entry.name, "path": entry.path}) + + # Sort runs alphabetically by name for consistency + runs.sort(key=lambda x: x["name"]) + return runs + + +def get_db_path(run_dir): + return os.path.join(run_dir, "training_logs.sqlite") + + +def _fetch_run_metrics(run_dir): + """Fetch all rows from the metrics table in the SQLite database for a single run.""" + db_path = get_db_path(run_dir) + + if not os.path.exists(db_path): + # Fallback to CSV if DB doesn't exist yet (for backwards compat) + csv_path = os.path.join(run_dir, "training_log.csv") + if os.path.exists(csv_path): + try: + return pd.read_csv(csv_path) + except Exception as e: + logging.error(f"Failed to read CSV: {e}") + return pd.DataFrame() + return pd.DataFrame() + + try: + with sqlite3.connect(db_path) as conn: + # Check if table exists + cursor = conn.cursor() + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='metrics'" + ) + if not cursor.fetchone(): + return pd.DataFrame() + + query = "SELECT * FROM metrics ORDER BY epoch ASC" + df = pd.read_sql_query(query, conn) + return df + except Exception as e: + logging.error(f"Database error reading metrics: {e}") + return pd.DataFrame() + + +def get_training_data(args, selected_runs=None): + """ + Fetch metric data across selected runs. + If selected_runs is None, defaults to all available runs. + Returns: Dict mapping run_name -> DataFrame. 
+ """ + all_runs = get_available_runs(args) + if selected_runs: + runs_to_fetch = [r for r in all_runs if r["name"] in selected_runs] + else: + runs_to_fetch = all_runs + + data_dict = {} + for run in runs_to_fetch: + df = _fetch_run_metrics(run["path"]) + if not df.empty: + data_dict[run["name"]] = df + + return data_dict + + +def get_available_images(args, selected_runs=None, cache_ttl=10): + """Scans for inference plot images and extracts metadata, caching results across runs.""" + import time + + with _cache_lock: + now = time.time() + # Need to ensure cache keying is valid, but keeping simple for now + # Refresh if TTL expired or cache is totally empty + if now - _image_cache["last_check"] < cache_ttl and _image_cache["images"]: + all_imgs = _image_cache["images"] + else: + all_runs = get_available_runs(args) + parsed_images = [] + + for run in all_runs: + search_pattern = os.path.join(run["path"], "*.png") + files = glob.glob(search_pattern) + + for file in files: + basename = os.path.basename(file) + # Expected format: SAMPLEID_epoch_NUM.png + try: + parts = basename.replace(".png", "").split("_epoch_") + if len(parts) == 2: + sample_id = parts[0] + epoch = int(parts[1]) + mtime = os.path.getmtime(file) + parsed_images.append( + { + "filename": basename, + "run_name": run["name"], + "run_path": run["path"], + "sample": sample_id, + "epoch": epoch, + "mtime": mtime, + } + ) + except Exception: + pass # Skip improperly named files + + # Sort by epoch descending then run alphabetical + parsed_images.sort(key=lambda x: (x["epoch"], x["run_name"]), reverse=True) + + _image_cache["last_check"] = now + _image_cache["images"] = parsed_images + all_imgs = parsed_images + + # Filter by selected runs after cache fetch + if selected_runs: + return [img for img in all_imgs if img["run_name"] in selected_runs] + return all_imgs diff --git a/src/spatial_transcript_former/dashboard/layout.py b/src/spatial_transcript_former/dashboard/layout.py new file mode 100644 index 
0000000..1765832 --- /dev/null +++ b/src/spatial_transcript_former/dashboard/layout.py @@ -0,0 +1,358 @@ +import dash +from dash import dcc, html + + +def create_layout(args): + """Generates the main dashboard layout.""" + import os + + run_dir = getattr(args, "run_dir", None) + run_name = os.path.basename(run_dir) if run_dir else "Multiple Runs" + + return html.Div( + style={ + "fontFamily": "Inter, Roboto, sans-serif", + "padding": "20px", + "backgroundColor": "#0f172a", + "color": "#f1f5f9", + "minHeight": "100vh", + }, + children=[ + # Header + html.Div( + style={ + "display": "flex", + "justifyContent": "space-between", + "alignItems": "center", + "marginBottom": "20px", + }, + children=[ + html.H1( + ( + f"Training Monitor: {run_name}" + if getattr(args, "run_dir", None) + else "Training Monitor: Compare Runs" + ), + style={"color": "#38bdf8", "margin": "0"}, + ), + html.Div( + id="last-updated", + style={"color": "#94a3b8", "fontStyle": "italic"}, + ), + ], + ), + # Controls + html.Div( + style={ + "display": "flex", + "gap": "20px", + "padding": "20px", + "backgroundColor": "#1e293b", + "borderRadius": "12px", + "marginBottom": "20px", + "alignItems": "center", + "flexWrap": "wrap", # Allow wrapping if many controls + }, + children=[ + html.Button( + "Pause Updates", + id="pause-button", + n_clicks=0, + style={ + "padding": "10px 20px", + "fontSize": "14px", + "cursor": "pointer", + "backgroundColor": "#2563eb", + "color": "white", + "border": "none", + "borderRadius": "8px", + "fontWeight": "bold", + "transition": "background-color 0.2s", + }, + ), + html.Div( + style={"flex": "2", "minWidth": "250px"}, + children=[ + html.Label( + "Select Runs:", + style={ + "fontWeight": "bold", + "color": "#cbd5e1", + "display": "block", + "marginBottom": "5px", + }, + ), + dcc.Dropdown( + id="run-selector", + options=[], # Populated by callback + value=[], + multi=True, + style={"color": "black"}, + placeholder="Select runs to compare...", + ), + ], + ), + html.Div( + 
style={"flex": "1", "minWidth": "200px"}, + children=[ + html.Label( + "Smoothing (Epochs):", + style={ + "fontWeight": "bold", + "color": "#cbd5e1", + "display": "block", + "marginBottom": "5px", + }, + ), + dcc.Slider( + id="smoothing-slider", + min=1, + max=50, + step=1, + value=1, + marks={ + i: {"label": str(i), "style": {"color": "#cbd5e1"}} + for i in [1, 10, 25, 50] + }, + ), + ], + ), + html.Button( + "Export Data", + id="export-button", + n_clicks=0, + style={ + "padding": "10px 20px", + "fontSize": "14px", + "cursor": "pointer", + "backgroundColor": "#10b981", # Emerald + "color": "white", + "border": "none", + "borderRadius": "8px", + "fontWeight": "bold", + }, + ), + dcc.Download( + id="download-data" + ), # Component to handle the actual file download + ], + ), + # KPI Cards (Top Row) + html.Div( + id="kpi-cards", + style={ + "display": "grid", + "gridTemplateColumns": "repeat(auto-fit, minmax(200px, 1fr))", + "gap": "20px", + "marginBottom": "20px", + }, + # Children populated by callback + ), + # Charts + html.Div( + style={ + "display": "grid", + "gridTemplateColumns": "1fr 1fr", + "gap": "20px", + "marginBottom": "30px", + }, + children=[ + dcc.Graph( + id="live-loss-graph", + animate=False, + style={ + "height": "400px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + dcc.Graph( + id="live-pcc-graph", + animate=False, + style={ + "height": "400px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + dcc.Graph( + id="live-variance-graph", + animate=False, + style={ + "height": "400px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + dcc.Graph( + id="live-lr-graph", + animate=False, + style={ + "height": "400px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + ], + ), + # Hardware Resource Charts + html.Div( + style={ + "display": "grid", + "gridTemplateColumns": "repeat(auto-fit, minmax(300px, 1fr))", + "gap": "20px", + "marginBottom": "30px", + }, + children=[ + dcc.Graph( + id="live-cpu-graph", 
+ animate=False, + style={ + "height": "300px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + dcc.Graph( + id="live-ram-graph", + animate=False, + style={ + "height": "300px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + dcc.Graph( + id="live-gpu-graph", + animate=False, + style={ + "height": "300px", + "borderRadius": "12px", + "overflow": "hidden", + }, + ), + ], + ), + # Image Preview Section + html.Div( + style={ + "backgroundColor": "#1e293b", + "padding": "30px", + "borderRadius": "12px", + }, + children=[ + html.H2( + "Spatial Predictions", + style={"color": "#e2e8f0", "marginTop": "0"}, + ), + # Controls for image + html.Div( + style={ + "display": "flex", + "gap": "20px", + "marginBottom": "20px", + }, + children=[ + html.Div( + style={"flex": "1"}, + children=[ + html.Label( + "Sample ID:", + style={ + "color": "#cbd5e1", + "display": "block", + "marginBottom": "5px", + }, + ), + dcc.Dropdown( + id="sample-dropdown", + options=[], # Populated dynamically + style={ + "color": "black" + }, # Text color inside dropdown + ), + ], + ), + html.Div( + style={"flex": "1"}, + children=[ + html.Label( + "Epoch:", + style={ + "color": "#cbd5e1", + "display": "block", + "marginBottom": "5px", + }, + ), + dcc.Dropdown( + id="epoch-dropdown", + options=[], # Populated dynamically based on sample + style={"color": "black"}, + ), + ], + ), + ], + ), + html.Div( + id="image-container", + style={ + "textAlign": "center", + "minHeight": "400px", + "display": "flex", + "alignItems": "center", + "justifyContent": "center", + "backgroundColor": "#0f172a", + "borderRadius": "8px", + }, + ), + ], + ), + # Polling Interval + dcc.Interval( + id="interval-component", + interval=args.interval, + n_intervals=0, + disabled=False, + ), + ], + ) + + +def create_kpi_card(title, value, subtitle=""): + """Helper to create a stylized KPI card.""" + from dash import html + + return html.Div( + style={ + "backgroundColor": "#1e293b", + "padding": "20px", + 
"borderRadius": "12px", + "boxShadow": "0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)", + "borderLeft": "4px solid #38bdf8", + }, + children=[ + html.H3( + title, + style={ + "margin": "0 0 10px 0", + "color": "#94a3b8", + "fontSize": "14px", + "textTransform": "uppercase", + }, + ), + html.Div( + value, + style={ + "fontSize": "28px", + "fontWeight": "bold", + "color": "#f8fafc", + "marginBottom": "5px", + }, + ), + ( + html.Div(subtitle, style={"fontSize": "12px", "color": "#64748b"}) + if subtitle + else None + ), + ], + ) diff --git a/src/spatial_transcript_former/data/utils.py b/src/spatial_transcript_former/data/utils.py index 77e6a05..8e6d58f 100644 --- a/src/spatial_transcript_former/data/utils.py +++ b/src/spatial_transcript_former/data/utils.py @@ -6,7 +6,12 @@ def get_sample_ids( - data_dir, precomputed=False, backbone="resnet50", feature_dir=None, max_samples=None + data_dir, + precomputed=False, + backbone="resnet50", + feature_dir=None, + max_samples=None, + organ=None, ): """ Find and filter HEST sample IDs based on metadata and data availability. @@ -40,8 +45,14 @@ def get_sample_ids( # Filter for existing files and Homo sapiens df_filtered = df[df["id"].isin(available_ids)] df_human = df_filtered[df_filtered["species"] == "Homo sapiens"] - human_ids = df_human["id"].tolist() + if organ: + print(f"Filtering for organ: {organ}") + df_human = df_human[df_human["organ"] == organ] + if df_human.empty: + print(f"Warning: No samples found for organ '{organ}'.") + + human_ids = df_human["id"].tolist() final_ids = human_ids if not final_ids: print("Warning: No Human samples found. 
Using all files.") @@ -108,56 +119,72 @@ def setup_dataloaders(args, train_ids, val_ids): feat_dir = os.path.join(args.data_dir, "patches", feat_dir_name) if args.whole_slide: - train_loader = get_hest_feature_dataloader( - args.data_dir, - train_ids, - batch_size=args.batch_size, - shuffle=True, - num_genes=args.num_genes, - n_neighbors=args.n_neighbors, - whole_slide_mode=True, - augment=args.augment, - feature_dir=feat_dir, - log1p=args.log_transform, + train_loader = ( + get_hest_feature_dataloader( + args.data_dir, + train_ids, + batch_size=args.batch_size, + shuffle=True, + num_genes=args.num_genes, + n_neighbors=args.n_neighbors, + whole_slide_mode=True, + augment=args.augment, + feature_dir=feat_dir, + log1p=args.log_transform, + ) + if train_ids + else None ) - val_loader = get_hest_feature_dataloader( - args.data_dir, - val_ids, - batch_size=args.batch_size, - shuffle=False, - num_genes=args.num_genes, - n_neighbors=args.n_neighbors, - whole_slide_mode=True, - augment=False, - feature_dir=feat_dir, - log1p=args.log_transform, + val_loader = ( + get_hest_feature_dataloader( + args.data_dir, + val_ids, + batch_size=args.batch_size, + shuffle=False, + num_genes=args.num_genes, + n_neighbors=args.n_neighbors, + whole_slide_mode=True, + augment=False, + feature_dir=feat_dir, + log1p=args.log_transform, + ) + if val_ids + else None ) else: - train_loader = get_hest_feature_dataloader( - args.data_dir, - train_ids, - batch_size=args.batch_size, - shuffle=True, - num_genes=args.num_genes, - n_neighbors=args.n_neighbors, - use_global_context=args.use_global_context, - global_context_size=args.global_context_size, - augment=args.augment, - feature_dir=feat_dir, - log1p=args.log_transform, + train_loader = ( + get_hest_feature_dataloader( + args.data_dir, + train_ids, + batch_size=args.batch_size, + shuffle=True, + num_genes=args.num_genes, + n_neighbors=args.n_neighbors, + use_global_context=args.use_global_context, + global_context_size=args.global_context_size, + 
augment=args.augment, + feature_dir=feat_dir, + log1p=args.log_transform, + ) + if train_ids + else None ) - val_loader = get_hest_feature_dataloader( - args.data_dir, - val_ids, - batch_size=args.batch_size, - shuffle=False, - num_genes=args.num_genes, - n_neighbors=args.n_neighbors, - use_global_context=args.use_global_context, - global_context_size=args.global_context_size, - augment=False, - feature_dir=feat_dir, - log1p=args.log_transform, + val_loader = ( + get_hest_feature_dataloader( + args.data_dir, + val_ids, + batch_size=args.batch_size, + shuffle=False, + num_genes=args.num_genes, + n_neighbors=args.n_neighbors, + use_global_context=args.use_global_context, + global_context_size=args.global_context_size, + augment=False, + feature_dir=feat_dir, + log1p=args.log_transform, + ) + if val_ids + else None ) else: # Base normalization @@ -180,26 +207,34 @@ def setup_dataloaders(args, train_ids, val_ids): if args.use_global_context: print("Warning: Global context only supported with pre-computed features.") - train_loader = get_hest_dataloader( - args.data_dir, - train_ids, - batch_size=args.batch_size, - shuffle=True, - num_genes=args.num_genes, - n_neighbors=args.n_neighbors, - transform=train_transform, - augment=args.augment, - log1p=args.log_transform, + train_loader = ( + get_hest_dataloader( + args.data_dir, + train_ids, + batch_size=args.batch_size, + shuffle=True, + num_genes=args.num_genes, + n_neighbors=args.n_neighbors, + transform=train_transform, + augment=args.augment, + log1p=args.log_transform, + ) + if train_ids + else None ) - val_loader = get_hest_dataloader( - args.data_dir, - val_ids, - batch_size=args.batch_size, - shuffle=False, - num_genes=args.num_genes, - n_neighbors=args.n_neighbors, - transform=val_transform, - log1p=args.log_transform, + val_loader = ( + get_hest_dataloader( + args.data_dir, + val_ids, + batch_size=args.batch_size, + shuffle=False, + num_genes=args.num_genes, + n_neighbors=args.n_neighbors, + 
transform=val_transform, + log1p=args.log_transform, + ) + if val_ids + else None ) return train_loader, val_loader diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index 77ace8a..743e308 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -154,20 +154,24 @@ def __init__( ) self.fusion_engine = nn.TransformerEncoder( - encoder_layer, num_layers=n_layers, enable_nested_tensor=False + encoder_layer, + num_layers=n_layers, + norm=nn.LayerNorm(token_dim), + enable_nested_tensor=False, ) - # Learnable temperature for cosine similarity scoring - # Initialized to log(1/0.07) ≈ 2.66 following CLIP convention - self.log_temperature = nn.Parameter(torch.tensor(2.6593)) - self.gene_reconstructor = nn.Linear(num_pathways, num_genes) if pathway_init is not None: with torch.no_grad(): # gene_reconstructor.weight is (num_genes, num_pathways) # pathway_init is (num_pathways, num_genes) - self.gene_reconstructor.weight.copy_(pathway_init.T) + # L1 normalization roughly matches variance scale of Kaiming initialization + # so that outputs don't explode before Softplus is applied. + pathway_init_norm = pathway_init / ( + pathway_init.sum(dim=1, keepdim=True) + 1e-6 + ) + self.gene_reconstructor.weight.copy_(pathway_init_norm.T) self.gene_reconstructor.bias.zero_() # Expose the MSigDB matrix for AuxiliaryPathwayLoss self._pathway_init_matrix = pathway_init.clone() @@ -341,27 +345,22 @@ def forward( # Extract processed patch tokens processed_patch_tokens = out[:, p:, :] # (B, S, D) - # 5. Compute pathway scores via cosine similarity with learnable temperature + # 5. 
Compute pathway scores via cosine similarity # L2-normalize both sets of tokens to produce cosine similarities in [-1, 1] norm_pathway = F.normalize(processed_pathway_tokens, dim=-1) # (B, P, D) - temperature = self.log_temperature.exp() # scalar if return_dense: # Dense prediction: per-patch cosine similarity with pathway tokens norm_patch = F.normalize(processed_patch_tokens, dim=-1) # (B, S, D) # (B, S, D) @ (B, D, P) -> (B, S, P) - pathway_scores = ( - torch.matmul(norm_patch, norm_pathway.transpose(1, 2)) * temperature - ) + pathway_scores = torch.matmul(norm_patch, norm_pathway.transpose(1, 2)) else: # Global prediction: pool patches first, then compute scores global_patch_token = processed_patch_tokens.mean( dim=1, keepdim=True ) # (B, 1, D) norm_global = F.normalize(global_patch_token, dim=-1) # (B, 1, D) - pathway_scores = ( - torch.matmul(norm_global, norm_pathway.transpose(1, 2)) * temperature - ) + pathway_scores = torch.matmul(norm_global, norm_pathway.transpose(1, 2)) pathway_scores = pathway_scores.squeeze(1) # (B, P) # Gene reconstruction (unified for both modes) @@ -372,7 +371,8 @@ def forward( theta = F.softplus(self.theta_reconstructor(pathway_scores)) + 1e-6 gene_expression = (pi, mu, theta) else: - gene_expression = self.gene_reconstructor(pathway_scores) + # Enforce non-negativity for gene counts (log1p or raw) + gene_expression = F.softplus(self.gene_reconstructor(pathway_scores)) results = [gene_expression] if return_pathways: diff --git a/src/spatial_transcript_former/predict.py b/src/spatial_transcript_former/predict.py index 36cd0da..6df3245 100644 --- a/src/spatial_transcript_former/predict.py +++ b/src/spatial_transcript_former/predict.py @@ -1,579 +1,152 @@ -import argparse import os -import torch import numpy as np import matplotlib.pyplot as plt -import seaborn as sns -from torch.utils.data import DataLoader - -from spatial_transcript_former.models.regression import HE2RNA, ViT_ST -from spatial_transcript_former.models.interaction 
import SpatialTranscriptFormer -from spatial_transcript_former.data.dataset import ( - get_hest_dataloader, - HEST_Dataset, - load_gene_expression_matrix, - load_global_genes, -) -import h5py - - -def plot_spatial_genes( - coords, truth, pred, gene_names, sample_id, save_path=None, cmap="viridis" -): - """ - Plots spatial maps of gene expression (dot plot format). - """ - num_plots = min(len(gene_names), 5) - fig, axes = plt.subplots(num_plots, 2, figsize=(12, 4 * num_plots)) - if num_plots == 1: - axes = np.expand_dims(axes, axis=0) - - for i in range(num_plots): - gene_name = gene_names[i] - - # Truth - # HEST coords are (x, y). Scatter takes (x, y). - # imshow is y-down. - sc = axes[i, 0].scatter( - coords[:, 0], - coords[:, 1], - c=truth[:, i], - cmap=cmap, - s=10, - edgecolors="none", - ) - axes[i, 0].set_title(f"{sample_id} - {gene_name} (TRUTH)") - axes[i, 0].invert_yaxis() # Match image space - axes[i, 0].set_aspect("equal") - plt.colorbar(sc, ax=axes[i, 0], shrink=0.6) - - # Prediction - sc = axes[i, 1].scatter( - coords[:, 0], coords[:, 1], c=pred[:, i], cmap=cmap, s=10, edgecolors="none" - ) - axes[i, 1].set_title(f"{sample_id} - {gene_name} (PRED)") - axes[i, 1].invert_yaxis() # Match image space - axes[i, 1].set_aspect("equal") - plt.colorbar(sc, ax=axes[i, 1], shrink=0.6) - - plt.tight_layout() - if save_path: - plt.savefig(save_path) - print(f"Plot saved to {save_path}") - else: - plt.show() - plt.close(fig) - plt.close("all") - - -def plot_histology_overlay( - image, - coords, - values, - gene_names, - sample_id, - scalef=1.0, - save_path=None, - cmap="viridis", -): - """ - Plots predictions as an overlay on the histology image. 
- """ - num_plots = min(len(gene_names), 3) - fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 6)) - if num_plots == 1: - axes = [axes] - - # Project coords to image pixels - pixel_coords = coords * scalef - - for i in range(num_plots): - gene_name = gene_names[i] - ax = axes[i] - - # Display Histology - ax.imshow(image) - - # Overlay Predictions - # pixel_coords[:, 0] = X, pixel_coords[:, 1] = Y - sc = ax.scatter( - pixel_coords[:, 0], - pixel_coords[:, 1], - c=values[:, i], - cmap=plt.get_cmap(cmap), - alpha=0.4, - s=15, - edgecolors="none", - ) - - ax.set_title(f"{sample_id} - {gene_name} Overlay") - ax.axis("off") - plt.colorbar(sc, ax=ax, fraction=0.046, pad=0.04) - - plt.tight_layout() - if save_path: - plt.savefig(save_path) - print(f"Overlay saved to {save_path}") - else: - plt.show() - plt.close(fig) - plt.close("all") - - -# Fixed representative pathways for visualization (MSigDB Hallmark names) -CORE_PATHWAYS = [ - "EPITHELIAL_MESENCHYMAL_TRANSITION", - "WNT_BETA_CATENIN_SIGNALING", - "INFLAMMATORY_RESPONSE", - "ANGIOGENESIS", - "APOPTOSIS", - "TNFA_SIGNALING_VIA_NFKB", -] +from typing import List, Optional def plot_training_summary( - coords, - pathway_pred, - pathway_truth, - pathway_names, - sample_id, - histology_img=None, - scalef=1.0, - save_path=None, - cmap="jet", + coords: np.ndarray, + pathway_pred: np.ndarray, + pathway_truth: np.ndarray, + pathway_names: List[str], + sample_id: str = "Sample", + histology_img: Optional[np.ndarray] = None, + scalef: float = 1.0, + save_path: str = "plot.png", + plot_pathways_list: Optional[List[str]] = None, ): """ - Compact landscape training visualization dashboard without z-scoring. - Colorbars are explicitly externalized using make_axes_locatable so plots - stay uniform in size. + Creates a detailed summary plot of pathway predictions vs ground truth. Args: - coords: (N, 2) spot coordinates. - pathway_pred: (N, P) predicted pathway activations. 
- pathway_truth: (N, P) ground-truth pathway activations. - pathway_names: List of pathway names (length P). - sample_id: Sample identifier for titles. - histology_img: Optional H&E image array. - scalef: Scale factor for projecting coords onto histology. - save_path: Where to save the figure. - cmap: Colormap for scatter plots. + coords: (N, 2) spatial coordinates. + pathway_pred: (N, P) predicted pathway activations (spatial z-score). + pathway_truth: (N, P) ground truth pathway activations (spatial z-score). + pathway_names: List of P pathway names. + sample_id: Identifier for the plot title. + histology_img: Optional RGB image for background. + scalef: Scale factor for histology image. + save_path: Path to save the PNG. + plot_pathways_list: Optional list of specific pathway names to plot. If None, plots the first 6 available. """ - import matplotlib.gridspec as gridspec - from mpl_toolkits.axes_grid1 import make_axes_locatable - - # Find indices for our fixed pathways - name_to_idx = {} - if pathway_names is not None: - for i, name in enumerate(pathway_names): - short = name.replace("HALLMARK_", "") - name_to_idx[short] = i + # Format short names for easier matching + short_names = [n.replace("HALLMARK_", "") for n in pathway_names] - display_pathways = [] - if pathway_names is not None: - if len(pathway_names) <= 6: - for pw in pathway_names: - short = pw.replace("HALLMARK_", "") - if short in name_to_idx: - display_pathways.append((short, name_to_idx[short])) - else: - for pw in CORE_PATHWAYS: - if pw in name_to_idx: - display_pathways.append((pw, name_to_idx[pw])) + selected_indices = [] + plot_names = [] - if not display_pathways: - for pw in pathway_names[:6]: - short = pw.replace("HALLMARK_", "") - if short in name_to_idx: - display_pathways.append((short, name_to_idx[short])) - - if not display_pathways: - print("Warning: No viable pathways found in model to display. 
Skipping plot.") + # Determine which pathways to plot + if plot_pathways_list is not None and len(plot_pathways_list) > 0: + target_pathways = [p.replace("HALLMARK_", "") for p in plot_pathways_list] + else: + # Default to the first 6 pathways available if none specified + target_pathways = short_names[:6] + + for target_pw in target_pathways: + if target_pw in short_names: + idx = short_names.index(target_pw) + selected_indices.append(idx) + plot_names.append(target_pw) + elif f"HALLMARK_{target_pw}" in pathway_names: + # Fallback exact match if user supplied the full name + idx = pathway_names.index(f"HALLMARK_{target_pw}") + selected_indices.append(idx) + plot_names.append(target_pw) + + if not selected_indices: + print( + f"Warning: None of the requested pathways '{target_pathways}' were found in the model's output." + ) return - n_per_row = 2 - n_pw = len(display_pathways) - n_rows = int(np.ceil(n_pw / n_per_row)) - has_histology = histology_img is not None - - vis_coords = coords * scalef if has_histology else coords + num_plots = len(selected_indices) - # Create figure: Width accommodates 1 Histology + (2 pathways * 3 cols each (Truth, Pred, Cbar)) - # Total ~7 logical columns - fig = plt.figure(figsize=(24, 6 * n_rows), constrained_layout=False) - fig.patch.set_facecolor("#1a1a2e") + # 1. Apply style BEFORE creating subplots to ensure all text/axes inherit it properly + plt.style.use("dark_background") - # Outer Grid: 1 col for Histology, 1 for all pathways - outer = gridspec.GridSpec( - 1, - 2, - figure=fig, - width_ratios=[1, 3.5], - left=0.02, - right=0.98, - top=0.92, - bottom=0.05, - wspace=0.1, + # 2. 
Adjust width to be less extremely stretched (3 per plot instead of 5) + width = max(10, 3 * num_plots) + height = 8 + fig, axes = plt.subplots( + 2, num_plots, figsize=(width, height), squeeze=False, layout="constrained" ) - # --- Left: Histology panel --- - ax_hist = fig.add_subplot(outer[0, 0]) - if has_histology: - ax_hist.imshow(histology_img) - ax_hist.set_title("Histology", fontsize=16, color="white", pad=12) - ax_hist.axis("off") - ax_hist.set_facecolor("#0d0d1a") - # Anchor to top so it doesn't float randomly if pathways are tall - ax_hist.set_anchor("N") - - # --- Right: Pathway Grids --- - # For each pathway, we want [GT | Pred | Colorbar] = 3 sub-columns. - # Total columns = n_per_row * 3 - n_cols = n_per_row * 3 - # Configure width ratios: Maps get width 1, Colorbars get width 0.1 - col_widths = [1, 1, 0.1] * n_per_row + # Set the specific dark slate background color early + fig.patch.set_facecolor("#0f172a") - inner = gridspec.GridSpecFromSubplotSpec( - n_rows, - n_cols, - subplot_spec=outer[0, 1], - width_ratios=col_widths, - hspace=0.35, - wspace=0.15, + plt.suptitle( + f"Pathway Validation (Spatial Z-Score): {sample_id}", + fontsize=18, + color="white", + fontweight="bold", ) - for idx, (pw_name, pw_idx) in enumerate(display_pathways): - row = idx // n_per_row - pw_col_base = (idx % n_per_row) * 3 - - col_gt = pw_col_base - col_pred = pw_col_base + 1 - col_cbar = pw_col_base + 2 + for i, idx in enumerate(selected_indices): + name = plot_names[i] - label = pw_name.replace("_", " ").title() - if len(label) > 30: - label = label[:27] + "..." 
+ # Pred and Truth for this pathway (z-score scale) + p = pathway_pred[:, idx] + t = pathway_truth[:, idx] - truth_vals = pathway_truth[:, pw_idx] - pred_vals = pathway_pred[:, pw_idx] + # Vmin/Vmax for shared z-score scale + vmin = min(p.min(), t.min()) + vmax = max(p.max(), t.max()) - # Both truth and pred are now in the same units (mean log1p expression - # of pathway member genes), so shared bounds give a fair comparison - vmin = min(truth_vals.min(), pred_vals.min()) - vmax = max(truth_vals.max(), pred_vals.max()) + for j, (vals, title) in enumerate([(t, "Truth"), (p, "Prediction")]): + ax = axes[j, i] - sc = None - for col, vals, suffix in [ - (col_gt, truth_vals, "Truth"), - (col_pred, pred_vals, "Pred"), - ]: - ax = fig.add_subplot(inner[row, col]) - ax.set_facecolor("#0d0d1a") - - if has_histology: - ax.imshow(histology_img, alpha=0.25) + if histology_img is not None: + ax.imshow(histology_img, alpha=0.4) + # Apply scale factor to coordinates if histology is present + c_plot = coords * scalef + else: + c_plot = coords sc = ax.scatter( - vis_coords[:, 0], - vis_coords[:, 1], + c_plot[:, 0], + c_plot[:, 1], c=vals, - cmap=cmap, - s=6, - edgecolors="none", + cmap="viridis", + s=2, + alpha=1.0, vmin=vmin, vmax=vmax, + edgecolors="none", ) - if not has_histology: - ax.invert_yaxis() - ax.set_aspect("equal") - ax.axis("off") - ax.set_title(f"{label}\n{suffix}", fontsize=12, color="white", pad=6) + # Titles on top row + if j == 0: + ax.set_title(name, fontsize=14, pad=10) + # Row labels on first column + if i == 0: + ax.set_ylabel(title, fontsize=14, labelpad=10) - # Plot the colorbar in its dedicated axis so it NEVER shrinks the prediction map - cax = fig.add_subplot(inner[row, col_cbar]) - cb = plt.colorbar(sc, cax=cax) - cb.ax.tick_params(labelsize=9, colors="white") - cb.outline.set_edgecolor("white") + ax.set_xticks([]) + ax.set_yticks([]) - fig.suptitle( - f"{sample_id} — Pathway Activation Summary", - fontsize=20, - fontweight="bold", - color="white", - 
y=0.98, - ) + # Remove bounding box for cleaner spatial visualization + for spine in ax.spines.values(): + spine.set_visible(False) - if save_path: - plt.savefig( - save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor() - ) - print(f"Training summary saved to {save_path}") - else: - plt.show() - plt.close(fig) - plt.close("all") - - -def main(): - parser = argparse.ArgumentParser( - description="Predict and Visualize Spatial Transcriptomics" - ) - parser.add_argument("--data-dir", type=str, required=True) - parser.add_argument("--sample-id", type=str, required=True) - parser.add_argument("--model-path", type=str, required=True) - parser.add_argument( - "--model-type", - type=str, - default="he2rna", - choices=["he2rna", "vit_st", "interaction"], - ) - parser.add_argument("--num-genes", type=int, default=1000) - parser.add_argument("--output-dir", type=str, default="./results") - parser.add_argument( - "--n-neighbors", type=int, default=0, help="Number of spatial neighbors to use" - ) - parser.add_argument("--token-dim", type=int, default=256) - parser.add_argument("--n-heads", type=int, default=4) - parser.add_argument("--n-layers", type=int, default=2) - parser.add_argument( - "--num-pathways", - type=int, - default=50, - help="Number of pathways in the bottleneck", - ) - parser.add_argument( - "--backbone", - type=str, - default="resnet50", - help="Backbone for feature extraction", - ) - parser.add_argument( - "--plot-pathways", action="store_true", help="Visualize pathway activations" - ) - parser.add_argument( - "--loss", - type=str, - default="mse", - choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], - help="Loss function used for training (needed for model reconstruction)", - ) - - args = parser.parse_args() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Paths - patches_dir = os.path.join(args.data_dir, "patches") - if not os.path.exists(patches_dir): - patches_dir = args.data_dir - st_dir = 
os.path.join(args.data_dir, "st") - - h5_path = os.path.join(patches_dir, f"{args.sample_id}.h5") - h5ad_path = os.path.join(st_dir, f"{args.sample_id}.h5ad") - - # Load Model - if args.model_type == "he2rna": - model = HE2RNA(num_genes=args.num_genes, backbone=args.backbone) - elif args.model_type == "vit_st": - model = ViT_ST(num_genes=args.num_genes, model_name=args.backbone) - elif args.model_type == "interaction": - model = SpatialTranscriptFormer( - num_genes=args.num_genes, - token_dim=args.token_dim, - n_heads=args.n_heads, - n_layers=args.n_layers, - num_pathways=args.num_pathways, - backbone_name=args.backbone, - output_mode="zinb" if args.loss == "zinb" else "counts", - ) - - state_dict = torch.load(args.model_path, map_location=device, weights_only=True) - # Handle possible torch.compile prefix - new_state_dict = {} - for k, v in state_dict.items(): - if k.startswith("_orig_mod."): - new_state_dict[k[len("_orig_mod.") :]] = v - else: - new_state_dict[k] = v - model.load_state_dict(new_state_dict) - model.to(device) - model.eval() - - # Load Sample - with h5py.File(h5_path, "r") as f: - patch_barcodes = f["barcode"][:].flatten() - coords = f["coords"][:] - - # Load common genes for consistency - try: - common_gene_names = load_global_genes(args.data_dir, args.num_genes) - except Exception as e: - print( - f"Warning: Could not load global genes, falling back to sample top genes: {e}" - ) - common_gene_names = None - - gene_matrix, mask, gene_names = load_gene_expression_matrix( - h5ad_path, - patch_barcodes, - selected_gene_names=common_gene_names, - num_genes=args.num_genes, - ) - - coord_subset = coords[mask] - indices = np.where(mask)[0] + # Make subplot background transparent or match figure + ax.set_facecolor("#0f172a") - # Neighborhood computation if requested - neighborhood_indices = None - if args.n_neighbors > 0: - from scipy.spatial import KDTree - - with h5py.File(h5_path, "r") as f: - coords_all = f["coords"][:] - tree = KDTree(coords_all) - 
dists, idxs = tree.query(coord_subset, k=args.n_neighbors + 1) - final_neighbors = [] - for i, center_idx in enumerate(indices): - n_idxs = idxs[i] - n_idxs = n_idxs[n_idxs != center_idx] - final_neighbors.append(n_idxs[: args.n_neighbors]) - neighborhood_indices = np.array(final_neighbors) - else: - with h5py.File(h5_path, "r") as f: - coords_all = f["coords"][:] - - dataset = HEST_Dataset( - h5_path, - coord_subset, - gene_matrix, - indices=indices, - neighborhood_indices=neighborhood_indices, - coords_all=coords_all, - ) - loader = DataLoader(dataset, batch_size=32, shuffle=False) - - all_preds = [] - all_truth = [] - all_pathways = [] - - print(f"Running inference on {len(dataset)} patches...") - with torch.no_grad(): - for images, targets, rel_coords in loader: - images = images.to(device) - rel_coords = rel_coords.to(device) - - if isinstance(model, SpatialTranscriptFormer): - # Request pathways if argument set - output = model( - images, rel_coords=rel_coords, return_pathways=args.plot_pathways - ) - if isinstance(output, tuple) and args.plot_pathways: - preds = output[0] - pathways = output[1] - all_pathways.append(pathways.cpu().numpy()) - else: - preds = output - else: - preds = model(images) - - # Unpack ZINB tuple if generated - if isinstance(preds, tuple): - preds = preds[1] # Use mean component - - all_preds.append(preds.cpu().numpy()) - all_truth.append(targets.numpy()) - - all_preds = np.concatenate(all_preds, axis=0) - all_truth = np.concatenate(all_truth, axis=0) - if all_pathways: - all_pathways = np.concatenate(all_pathways, axis=0) - - # Save results - os.makedirs(args.output_dir, exist_ok=True) + if histology_img is not None: + # Invert Y to match histology orientation + ax.invert_yaxis() - # Standard Plot - save_dot_path = os.path.join( - args.output_dir, f"{args.sample_id}_spatial_inference.png" - ) - plot_spatial_genes( - coord_subset, - all_truth, - all_preds, - gene_names[:5], - args.sample_id, - save_path=save_dot_path, - cmap="jet", + # 
Add a colorbar spanning the figure + cbar = fig.colorbar( + sc, + ax=axes.ravel().tolist(), + orientation="horizontal", + shrink=0.5, + aspect=40, + pad=0.05, ) + cbar.set_label("Relative Expression (Spatial Z-Score)", fontsize=14, labelpad=10) + cbar.ax.tick_params(labelsize=12) - if args.plot_pathways and len(all_pathways) > 0: - save_pathway_path = os.path.join( - args.output_dir, f"{args.sample_id}_pathway_activations.png" - ) - plot_spatial_pathways( - coord_subset, - all_pathways, - args.sample_id, - save_path=save_pathway_path, - cmap="jet", - ) - - # Histology Overlay Plot - try: - # Construct h5ad path robustly - h5ad_overlay_path = h5_path.replace(".h5", ".h5ad") - if "patches" in h5ad_overlay_path: - h5ad_overlay_path = h5ad_overlay_path.replace("patches", "st") - - if os.path.exists(h5ad_overlay_path): - with h5py.File(h5ad_overlay_path, "r") as f: - # Robust group access - if "uns" in f and "spatial" in f["uns"]: - spatial = f["uns/spatial"] - sample_key = ( - list(spatial.keys())[0] if len(spatial.keys()) > 0 else None - ) - if sample_key: - img_group = spatial[sample_key]["images"] - img_key = ( - "downscaled_fullres" - if "downscaled_fullres" in img_group - else list(img_group.keys())[0] - ) - img = img_group[img_key][:] - - scale_group = spatial[sample_key]["scalefactors"] - scale_key = ( - "tissue_downscaled_fullres_scalef" - if "tissue_downscaled_fullres_scalef" in scale_group - else list(scale_group.keys())[0] - ) - scalef = scale_group[scale_key][()] - - save_overlay_path = os.path.join( - args.output_dir, f"{args.sample_id}_histology_overlay.png" - ) - print(f"Generating histology overlay for {args.sample_id}...") - plot_histology_overlay( - img, - coord_subset, - all_preds, - gene_names, - args.sample_id, - scalef=scalef, - save_path=save_overlay_path, - cmap="jet", - ) - else: - print( - f"No 'uns/spatial' found in {h5ad_overlay_path}. 
Keys: {list(f.keys())}" - ) - else: - print(f"H5AD file not found for overlay: {h5ad_overlay_path}") - except Exception as e: - print(f"Warning: Could not generate histology overlay: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - main() + plt.savefig(save_path, dpi=200, bbox_inches="tight", facecolor="#0f172a") + plt.close(fig) + print(f"Saved pathway summary to {save_path}") diff --git a/src/spatial_transcript_former/train.py b/src/spatial_transcript_former/train.py index 70c5fb1..286be45 100644 --- a/src/spatial_transcript_former/train.py +++ b/src/spatial_transcript_former/train.py @@ -26,331 +26,12 @@ from spatial_transcript_former.visualization import run_inference_plot from spatial_transcript_former.data.utils import get_sample_ids, setup_dataloaders -# --------------------------------------------------------------------------- -# Model Setup -# --------------------------------------------------------------------------- - - -def setup_model(args, device): - """Initialize and optionally compile the model.""" - if args.model == "he2rna": - model = HE2RNA( - num_genes=args.num_genes, backbone=args.backbone, pretrained=args.pretrained - ) - elif args.model == "vit_st": - model = ViT_ST( - num_genes=args.num_genes, - model_name=args.backbone if "vit_" in args.backbone else "vit_b_16", - pretrained=args.pretrained, - ) - elif args.model == "interaction": - print( - f"Initializing SpatialTranscriptFormer ({args.backbone}, pretrained={args.pretrained})" - ) - - # Load biological pathway initialization if requested - pathway_init = None - if getattr(args, "pathway_init", False): - from spatial_transcript_former.data.pathways import ( - get_pathway_init, - MSIGDB_URLS, - ) - import json - - genes_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - "global_genes.json", - ) - if not os.path.exists(genes_path): - genes_path = "global_genes.json" - with open(genes_path) as f: - gene_list = json.load(f) 
- - if getattr(args, "custom_gmt", None): - urls = args.custom_gmt - elif getattr(args, "pathways", None): - # If specific pathways requested but no custom GMT, search standard collections - urls = [ - MSIGDB_URLS["hallmarks"], - MSIGDB_URLS["c2_medicus"], - MSIGDB_URLS["c2_cgp"], - ] - else: - # Default to just the 50 Hallmarks to prevent VRAM exhaustion - urls = [MSIGDB_URLS["hallmarks"]] - - pathway_init, pathway_names = get_pathway_init( - gene_list[: args.num_genes], gmt_urls=urls, filter_names=args.pathways - ) - # Override num_pathways based on actual parsed paths - args.num_pathways = len(pathway_names) - print(f"Num pathways forced to {args.num_pathways} based on init dict") - - model = SpatialTranscriptFormer( - num_genes=args.num_genes, - backbone_name=args.backbone, - pretrained=args.pretrained, - token_dim=args.token_dim, - n_heads=args.n_heads, - n_layers=args.n_layers, - num_pathways=args.num_pathways, - pathway_init=pathway_init, - use_spatial_pe=args.use_spatial_pe, - output_mode="zinb" if args.loss == "zinb" else "counts", - interactions=getattr(args, "interactions", None), - ) - elif args.model == "attention_mil": - from spatial_transcript_former.models.mil import AttentionMIL - - model = AttentionMIL( - output_dim=args.num_genes, - backbone_name=args.backbone, - pretrained=args.pretrained, - ) - elif args.model == "transmil": - from spatial_transcript_former.models.mil import TransMIL - - model = TransMIL( - output_dim=args.num_genes, - backbone_name=args.backbone, - pretrained=args.pretrained, - ) - else: - raise ValueError(f"Unknown model: {args.model}") - - model.weak_supervision = getattr(args, "weak_supervision", False) - model = model.to(device) - - if args.compile: - print(f"Compiling model (backend='{args.compile_backend}')...") - try: - model = torch.compile(model, backend=args.compile_backend) - except Exception as e: - print(f"Compilation failed: {e}. 
Using eager mode.") - - return model - - -def setup_criterion(args, pathway_init=None): - """Create loss function from CLI args. - - If ``pathway_init`` is provided and ``--pathway-loss-weight > 0``, - wraps the base criterion with :class:`AuxiliaryPathwayLoss`. - """ - if args.loss == "pcc": - base = PCCLoss() - elif args.loss == "mse_pcc": - base = CompositeLoss(alpha=args.pcc_weight) - elif args.loss == "zinb": - base = ZINBLoss() - elif args.loss == "poisson": - base = nn.PoissonNLLLoss(log_input=True) - elif args.loss == "logcosh": - print("Using HuberLoss as proxy for LogCosh") - base = nn.HuberLoss() - else: - base = MaskedMSELoss() - - pw_weight = getattr(args, "pathway_loss_weight", 0.0) - if pathway_init is not None and pw_weight > 0: - from spatial_transcript_former.training.losses import AuxiliaryPathwayLoss - - print(f"Wrapping criterion with AuxiliaryPathwayLoss (lambda={pw_weight})") - return AuxiliaryPathwayLoss(pathway_init, base, lambda_pathway=pw_weight) - - return base - - -# --------------------------------------------------------------------------- -# Checkpoint Management -# --------------------------------------------------------------------------- - - -def save_checkpoint( - model, optimizer, scaler, epoch, best_val_loss, output_dir, model_name -): - """Save training state for resuming.""" - save_dict = { - "epoch": epoch, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "best_val_loss": best_val_loss, - } - if scaler is not None: - save_dict["scaler_state_dict"] = scaler.state_dict() - - torch.save(save_dict, os.path.join(output_dir, f"latest_model_{model_name}.pth")) - - -def load_checkpoint(model, optimizer, scaler, output_dir, model_name, device): - """ - Load checkpoint if it exists. - - Returns: - tuple: (start_epoch, best_val_loss) - """ - ckpt_path = os.path.join(output_dir, f"latest_model_{model_name}.pth") - if not os.path.exists(ckpt_path): - print(f"No checkpoint found at {ckpt_path}. 
Starting from scratch.") - return 0, float("inf") - - print(f"Resuming from {ckpt_path}...") - checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) - - model.load_state_dict(checkpoint["model_state_dict"]) - if "optimizer_state_dict" in checkpoint: - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - if "scaler_state_dict" in checkpoint and scaler is not None: - scaler.load_state_dict(checkpoint["scaler_state_dict"]) - - start_epoch = checkpoint.get("epoch", -1) + 1 - best_val_loss = checkpoint.get("best_val_loss", float("inf")) - - print(f"Resumed at epoch {start_epoch + 1}") - return start_epoch, best_val_loss - - -# --------------------------------------------------------------------------- -# CLI Arguments -# --------------------------------------------------------------------------- - - -def parse_args(): - parser = argparse.ArgumentParser(description="Train Spatial TranscriptFormer") - - # Data - g = parser.add_argument_group("Data") - g.add_argument( - "--data-dir", - type=str, - default=get_config("data_dirs", ["hest_data"])[0], - help="Root directory of HEST data", - ) - g.add_argument( - "--feature-dir", - type=str, - default=None, - help="Explicit feature directory (overrides auto-detection)", - ) - g.add_argument( - "--num-genes", type=int, default=get_config("training.num_genes", 1000) - ) - g.add_argument( - "--max-samples", type=int, default=None, help="Limit samples for debugging" - ) - g.add_argument( - "--precomputed", action="store_true", help="Use pre-computed features" - ) - g.add_argument( - "--whole-slide", action="store_true", help="Dense whole-slide prediction" - ) - g.add_argument("--seed", type=int, default=42) - g.add_argument( - "--log-transform", action="store_true", help="Log1p transform targets" - ) - - # Loss - parser.add_argument( - "--loss", - type=str, - default="mse", - choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], - ) - parser.add_argument( - "--pcc-weight", - type=float, - 
default=1.0, - help="Weight for PCC term in mse_pcc loss", - ) - parser.add_argument( - "--pathway-loss-weight", - type=float, - default=0.0, - help="Weight for auxiliary pathway PCC loss (0 = disabled)", - ) - - # Model - g = parser.add_argument_group("Model") - g.add_argument( - "--model", - type=str, - default="he2rna", - choices=["he2rna", "vit_st", "interaction", "attention_mil", "transmil"], - ) - g.add_argument("--backbone", type=str, default="resnet50") - g.add_argument("--no-pretrained", action="store_false", dest="pretrained") - g.set_defaults(pretrained=True) - g.add_argument("--num-pathways", type=int, default=50) - g.add_argument("--token-dim", type=int, default=256) - g.add_argument("--n-heads", type=int, default=4) - g.add_argument("--n-layers", type=int, default=2) - g.add_argument( - "--no-spatial-pe", - action="store_false", - dest="use_spatial_pe", - help="Disable spatial positional encoding", - ) - g.set_defaults(use_spatial_pe=False) - g.add_argument( - "--interactions", - nargs="+", - default=None, - help="Attention interactions to enable: p2p, p2h, h2p, h2h (default: all)", - ) - - # Training - g = parser.add_argument_group("Training") - g.add_argument("--epochs", type=int, default=get_config("training.epochs", 10)) - g.add_argument( - "--batch-size", type=int, default=get_config("training.batch_size", 32) - ) - g.add_argument("--grad-accum-steps", type=int, default=1) - g.add_argument( - "--lr", type=float, default=get_config("training.learning_rate", 1e-4) - ) - g.add_argument("--weight-decay", type=float, default=0.0) - g.add_argument("--warmup-epochs", type=int, default=10) - g.add_argument("--sparsity-lambda", type=float, default=0.0) - g.add_argument("--augment", action="store_true") - g.add_argument("--use-amp", action="store_true") - g.add_argument( - "--output-dir", - type=str, - default=get_config("training.output_dir", "./checkpoints"), - ) - g.add_argument("--compile", action="store_true") - g.add_argument("--resume", 
action="store_true") - - # Advanced - g = parser.add_argument_group("Advanced") - g.add_argument("--n-neighbors", type=int, default=0) - g.add_argument("--use-global-context", action="store_true") - g.add_argument("--global-context-size", type=int, default=128) - g.add_argument("--compile-backend", type=str, default="inductor") - g.add_argument("--plot-pathways", action="store_true") - g.add_argument( - "--weak-supervision", action="store_true", help="Bag-level training for MIL" - ) - g.add_argument( - "--pathway-init", - action="store_true", - help="Initialize gene_reconstructor with MSigDB Hallmarks", - ) - g.add_argument( - "--pathways", - nargs="+", - default=None, - help="List of MSigDB pathway names to explicitly instantiate (e.g. HALLMARK_APOPTOSIS). If none are provided but --pathway-init is enabled, all pathways in the provided GMTs will be loaded.", - ) - g.add_argument( - "--custom-gmt", - nargs="+", - default=None, - help="List of URLs or local paths to custom .gmt files for pathway initialization. 
Overrides standard MSigDB defaults if provided.", - ) - - return parser.parse_args() - +from spatial_transcript_former.training.arguments import parse_args +from spatial_transcript_former.training.builder import setup_model, setup_criterion +from spatial_transcript_former.training.checkpoint import ( + save_checkpoint, + load_checkpoint, +) # --------------------------------------------------------------------------- # Main @@ -363,6 +44,25 @@ def main(): print(f"Device: {device}") set_seed(args.seed) + # Global gene count synchronization + genes_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "global_genes.json", + ) + if not os.path.exists(genes_path): + genes_path = "global_genes.json" + if os.path.exists(genes_path): + import json + + with open(genes_path, "r") as f: + gene_list = json.load(f) + args.num_genes = min(args.num_genes, len(gene_list)) + print(f"Validated global gene count: {args.num_genes}") + else: + print( + f"Warning: global_genes.json not found. Using requested num_genes={args.num_genes}" + ) + # 1. 
Data final_ids = get_sample_ids( args.data_dir, @@ -370,6 +70,7 @@ def main(): backbone=args.backbone, feature_dir=args.feature_dir, max_samples=args.max_samples, + organ=args.organ, ) np.random.shuffle(final_ids) @@ -395,16 +96,21 @@ def main(): # LR scheduler: cosine annealing with optional linear warmup warmup_epochs = args.warmup_epochs + warmup_scheduler = optim.lr_scheduler.LinearLR( + optimizer, start_factor=0.01, total_iters=max(1, warmup_epochs) + ) cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs - warmup_epochs, eta_min=1e-6 + optimizer, T_max=max(1, args.epochs - warmup_epochs), eta_min=1e-6 ) - def lr_lambda(epoch): - if epoch < warmup_epochs: - return epoch / max(warmup_epochs, 1) - return 1.0 # cosine scheduler handles the rest - - warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + if warmup_epochs > 0: + main_scheduler = optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, cosine_scheduler], + milestones=[warmup_epochs], + ) + else: + main_scheduler = cosine_scheduler scaler = torch.amp.GradScaler("cuda") if args.use_amp else None print(f"Loss: {criterion.__class__.__name__}") @@ -417,11 +123,20 @@ def lr_lambda(epoch): # 4. Resume start_epoch, best_val_loss = 0, float("inf") + schedulers = {"main": main_scheduler} if args.resume: - start_epoch, best_val_loss = load_checkpoint( - model, optimizer, scaler, args.output_dir, args.model, device + start_epoch, best_val_loss, loaded_schedulers = load_checkpoint( + model, optimizer, scaler, schedulers, args.output_dir, args.model, device ) + # Fallback for old checkpoints: manually step the scheduler to catch up + if start_epoch > 0 and main_scheduler.last_epoch < start_epoch: + print( + f"Old checkpoint detected. Manually stepping scheduler {start_epoch} times to catch up..." + ) + for _ in range(start_epoch): + main_scheduler.step() + # 5. 
Training Loop for epoch in range(start_epoch, args.epochs): print(f"\nEpoch {epoch + 1}/{args.epochs}") @@ -432,7 +147,6 @@ def lr_lambda(epoch): criterion, optimizer, device, - sparsity_lambda=args.sparsity_lambda, whole_slide=args.whole_slide, scaler=scaler, grad_accum_steps=args.grad_accum_steps, @@ -453,10 +167,7 @@ def lr_lambda(epoch): ) # Step LR scheduler - if epoch < warmup_epochs: - warmup_scheduler.step() - else: - cosine_scheduler.step() + main_scheduler.step() # Log epoch epoch_row = { @@ -472,6 +183,21 @@ def lr_lambda(epoch): epoch_row["pred_variance"] = round(val_metrics["pred_variance"], 6) if val_metrics.get("attn_correlation") is not None: epoch_row["attn_correlation"] = round(val_metrics["attn_correlation"], 4) + + # Hardware Resource Monitoring + try: + import psutil + + epoch_row["sys_cpu_percent"] = psutil.cpu_percent() + epoch_row["sys_ram_percent"] = psutil.virtual_memory().percent + except ImportError: + pass + + if torch.cuda.is_available(): + epoch_row["sys_gpu_mem_mb"] = round( + torch.cuda.memory_allocated() / (1024**2), 2 + ) + logger.log_epoch(epoch + 1, epoch_row) # Save best @@ -483,12 +209,21 @@ def lr_lambda(epoch): # Save latest save_checkpoint( - model, optimizer, scaler, epoch, best_val_loss, args.output_dir, args.model + model, + optimizer, + scaler, + schedulers, + epoch, + best_val_loss, + args.output_dir, + args.model, ) - # Periodic visualization (only when --plot-pathways is set) - if args.plot_pathways and val_ids: - run_inference_plot(model, args, val_ids[0], epoch, device) + # Periodic visualization + if val_ids and (epoch + 1) % args.vis_interval == 0: + vis_id = args.vis_sample if args.vis_sample else val_ids[0] + print(f"Generating visualization for sample {vis_id}...") + run_inference_plot(model, args, vis_id, epoch + 1, device) # 6. 
Finalize logger.finalize(best_val_loss) diff --git a/src/spatial_transcript_former/training/arguments.py b/src/spatial_transcript_former/training/arguments.py new file mode 100644 index 0000000..d125583 --- /dev/null +++ b/src/spatial_transcript_former/training/arguments.py @@ -0,0 +1,161 @@ +import argparse +from spatial_transcript_former.config import get_config + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train Spatial TranscriptFormer") + + # Data + g = parser.add_argument_group("Data") + g.add_argument( + "--data-dir", + type=str, + default=get_config("data_dirs", ["hest_data"])[0], + help="Root directory of HEST data", + ) + g.add_argument( + "--feature-dir", + type=str, + default=None, + help="Explicit feature directory (overrides auto-detection)", + ) + g.add_argument( + "--num-genes", type=int, default=get_config("training.num_genes", 1000) + ) + g.add_argument( + "--max-samples", type=int, default=None, help="Limit samples for debugging" + ) + g.add_argument( + "--precomputed", action="store_true", help="Use pre-computed features" + ) + g.add_argument( + "--whole-slide", action="store_true", help="Dense whole-slide prediction" + ) + g.add_argument("--seed", type=int, default=42) + g.add_argument( + "--log-transform", action="store_true", help="Log1p transform targets" + ) + g.add_argument("--organ", type=str, default=None, help="Filter samples by organ") + + # Loss + parser.add_argument( + "--loss", + type=str, + default="mse_pcc", + choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], + ) + parser.add_argument( + "--pcc-weight", + type=float, + default=1.0, + help="Weight for PCC term in mse_pcc loss", + ) + parser.add_argument( + "--pathway-loss-weight", + type=float, + default=0.0, + help="Weight for auxiliary pathway PCC loss (0 = disabled)", + ) + + # Model + g = parser.add_argument_group("Model") + g.add_argument( + "--model", + type=str, + default="he2rna", + choices=["he2rna", "vit_st", "interaction", 
"attention_mil", "transmil"], + ) + g.add_argument("--backbone", type=str, default="resnet50") + g.add_argument("--no-pretrained", action="store_false", dest="pretrained") + g.set_defaults(pretrained=True) + g.add_argument("--num-pathways", type=int, default=50) + g.add_argument("--token-dim", type=int, default=256) + g.add_argument("--n-heads", type=int, default=4) + g.add_argument("--n-layers", type=int, default=2) + g.add_argument( + "--use-spatial-pe", + action="store_true", + help="Enable spatial positional encoding", + ) + g.add_argument( + "--interactions", + nargs="+", + default=None, + help="Attention interactions to enable: p2p, p2h, h2p, h2h (default: all)", + ) + + # Training + g = parser.add_argument_group("Training") + g.add_argument("--epochs", type=int, default=get_config("training.epochs", 10)) + g.add_argument( + "--batch-size", type=int, default=get_config("training.batch_size", 32) + ) + g.add_argument("--grad-accum-steps", type=int, default=1) + g.add_argument( + "--lr", type=float, default=get_config("training.learning_rate", 1e-4) + ) + g.add_argument("--weight-decay", type=float, default=0.0) + g.add_argument("--warmup-epochs", type=int, default=10) + g.add_argument("--augment", action="store_true") + g.add_argument("--use-amp", action="store_true") + g.add_argument( + "--output-dir", + type=str, + default=get_config("training.output_dir", "./checkpoints"), + ) + g.add_argument("--compile", action="store_true") + g.add_argument("--resume", action="store_true") + + # Advanced + g = parser.add_argument_group("Advanced") + g.add_argument("--n-neighbors", type=int, default=0) + g.add_argument("--use-global-context", action="store_true") + g.add_argument("--global-context-size", type=int, default=128) + g.add_argument("--compile-backend", type=str, default="inductor") + g.add_argument("--plot-pathways", action="store_true") + g.add_argument( + "--plot-pathways-list", + nargs="+", + default=None, + help="List of pathway names to exclusively 
visualize (e.g. HALLMARK_HYPOXIA). Defaults to the first 6 if None.", + ) + g.add_argument("--plot-attention", action="store_true") + g.add_argument( + "--return-attention", + action="store_true", + help="Extract and return attention maps during forward pass", + ) + g.add_argument( + "--weak-supervision", action="store_true", help="Bag-level training for MIL" + ) + g.add_argument( + "--pathway-init", + action="store_true", + help="Initialize gene_reconstructor with MSigDB Hallmarks", + ) + g.add_argument( + "--pathways", + nargs="+", + default=None, + help="List of MSigDB pathway names to explicitly instantiate (e.g. HALLMARK_APOPTOSIS). If none are provided but --pathway-init is enabled, all pathways in the provided GMTs will be loaded.", + ) + g.add_argument( + "--custom-gmt", + nargs="+", + default=None, + help="List of URLs or local paths to custom .gmt files for pathway initialization. Overrides standard MSigDB defaults if provided.", + ) + g.add_argument( + "--vis-interval", + type=int, + default=1, + help="Epoch interval for generating validation plots", + ) + g.add_argument( + "--vis-sample", + type=str, + default=None, + help="Sample ID to use for periodic visualization", + ) + + return parser.parse_args() diff --git a/src/spatial_transcript_former/training/builder.py b/src/spatial_transcript_former/training/builder.py new file mode 100644 index 0000000..6052409 --- /dev/null +++ b/src/spatial_transcript_former/training/builder.py @@ -0,0 +1,140 @@ +import os +import torch +import torch.nn as nn +from spatial_transcript_former.models import HE2RNA, ViT_ST, SpatialTranscriptFormer +from spatial_transcript_former.training.losses import ( + PCCLoss, + CompositeLoss, + MaskedMSELoss, + ZINBLoss, +) + + +def setup_model(args, device): + """Initialize and optionally compile the model.""" + if args.model == "he2rna": + model = HE2RNA( + num_genes=args.num_genes, backbone=args.backbone, pretrained=args.pretrained + ) + elif args.model == "vit_st": + model = 
ViT_ST( + num_genes=args.num_genes, + model_name=args.backbone if "vit_" in args.backbone else "vit_b_16", + pretrained=args.pretrained, + ) + elif args.model == "interaction": + print( + f"Initializing SpatialTranscriptFormer ({args.backbone}, pretrained={args.pretrained})" + ) + + # Load biological pathway initialization if requested + pathway_init = None + if getattr(args, "pathway_init", False): + from spatial_transcript_former.data.pathways import ( + get_pathway_init, + MSIGDB_URLS, + ) + import json + + genes_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "global_genes.json", + ) + if not os.path.exists(genes_path): + genes_path = "global_genes.json" + with open(genes_path) as f: + gene_list = json.load(f) + + if getattr(args, "custom_gmt", None): + urls = args.custom_gmt + elif getattr(args, "pathways", None): + # If specific pathways requested but no custom GMT, search standard collections + urls = [ + MSIGDB_URLS["hallmarks"], + MSIGDB_URLS["c2_medicus"], + MSIGDB_URLS["c2_cgp"], + ] + else: + # Default to just the 50 Hallmarks to prevent VRAM exhaustion + urls = [MSIGDB_URLS["hallmarks"]] + + pathway_init, pathway_names = get_pathway_init( + gene_list[: args.num_genes], gmt_urls=urls, filter_names=args.pathways + ) + # Override num_pathways based on actual parsed paths + args.num_pathways = len(pathway_names) + print(f"Num pathways forced to {args.num_pathways} based on init dict") + + model = SpatialTranscriptFormer( + num_genes=args.num_genes, + backbone_name=args.backbone, + pretrained=args.pretrained, + token_dim=args.token_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, + num_pathways=args.num_pathways, + pathway_init=pathway_init, + use_spatial_pe=args.use_spatial_pe, + output_mode="zinb" if args.loss == "zinb" else "counts", + interactions=getattr(args, "interactions", None), + ) + elif args.model == "attention_mil": + from spatial_transcript_former.models.mil import AttentionMIL + + model = 
AttentionMIL( + output_dim=args.num_genes, + backbone_name=args.backbone, + pretrained=args.pretrained, + ) + elif args.model == "transmil": + from spatial_transcript_former.models.mil import TransMIL + + model = TransMIL( + output_dim=args.num_genes, + backbone_name=args.backbone, + pretrained=args.pretrained, + ) + else: + raise ValueError(f"Unknown model: {args.model}") + + model.weak_supervision = getattr(args, "weak_supervision", False) + model = model.to(device) + + if args.compile: + print(f"Compiling model (backend='{args.compile_backend}')...") + try: + model = torch.compile(model, backend=args.compile_backend) + except Exception as e: + print(f"Compilation failed: {e}. Using eager mode.") + + return model + + +def setup_criterion(args, pathway_init=None): + """Create loss function from CLI args. + + If ``pathway_init`` is provided and ``--pathway-loss-weight > 0``, + wraps the base criterion with :class:`AuxiliaryPathwayLoss`. + """ + if args.loss == "pcc": + base = PCCLoss() + elif args.loss == "mse_pcc": + base = CompositeLoss(alpha=args.pcc_weight) + elif args.loss == "zinb": + base = ZINBLoss() + elif args.loss == "poisson": + base = nn.PoissonNLLLoss(log_input=True) + elif args.loss == "logcosh": + print("Using HuberLoss as proxy for LogCosh") + base = nn.HuberLoss() + else: + base = MaskedMSELoss() + + pw_weight = getattr(args, "pathway_loss_weight", 0.0) + if pathway_init is not None and pw_weight > 0: + from spatial_transcript_former.training.losses import AuxiliaryPathwayLoss + + print(f"Wrapping criterion with AuxiliaryPathwayLoss (lambda={pw_weight})") + return AuxiliaryPathwayLoss(pathway_init, base, lambda_pathway=pw_weight) + + return base diff --git a/src/spatial_transcript_former/training/checkpoint.py b/src/spatial_transcript_former/training/checkpoint.py new file mode 100644 index 0000000..e5ff772 --- /dev/null +++ b/src/spatial_transcript_former/training/checkpoint.py @@ -0,0 +1,76 @@ +import os +import torch + + +def save_checkpoint( + 
model, optimizer, scaler, schedulers, epoch, best_val_loss, output_dir, model_name +): + """Save training state for resuming.""" + save_dict = { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "best_val_loss": best_val_loss, + } + if scaler is not None: + save_dict["scaler_state_dict"] = scaler.state_dict() + if schedulers is not None: + save_dict["schedulers_state_dict"] = { + k: v.state_dict() for k, v in schedulers.items() + } + + torch.save(save_dict, os.path.join(output_dir, f"latest_model_{model_name}.pth")) + + +def load_checkpoint( + model, optimizer, scaler, schedulers, output_dir, model_name, device +): + """ + Load checkpoint if it exists. + + Returns: + tuple: (start_epoch, best_val_loss, loaded_schedulers) + """ + ckpt_path = os.path.join(output_dir, f"latest_model_{model_name}.pth") + if not os.path.exists(ckpt_path): + print(f"No checkpoint found at {ckpt_path}. Starting from scratch.") + return 0, float("inf"), False + + print(f"Resuming from {ckpt_path}...") + try: + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) + except (EOFError, RuntimeError, Exception) as e: + print( + f"Failed to load checkpoint at {ckpt_path} due to error: {e}. Starting from scratch." 
+ ) + return 0, float("inf"), False + + incompatible_keys = model.load_state_dict( + checkpoint["model_state_dict"], strict=False + ) + if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys: + print(f"Loaded with incompatible keys: {incompatible_keys}") + try: + if "optimizer_state_dict" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if "scaler_state_dict" in checkpoint and scaler is not None: + scaler.load_state_dict(checkpoint["scaler_state_dict"]) + + if "schedulers_state_dict" in checkpoint and schedulers is not None: + for k, v in checkpoint["schedulers_state_dict"].items(): + if k in schedulers: + schedulers[k].load_state_dict(v) + loaded_schedulers = True + else: + loaded_schedulers = False + except (ValueError, Exception) as e: + print( + f"Failed to load optimizer/scheduler states due to architecture change ({e}). Starting from scratch." + ) + return 0, float("inf"), False + + start_epoch = checkpoint.get("epoch", -1) + 1 + best_val_loss = checkpoint.get("best_val_loss", float("inf")) + + print(f"Resumed at epoch {start_epoch + 1}") + return start_epoch, best_val_loss, loaded_schedulers diff --git a/src/spatial_transcript_former/training/engine.py b/src/spatial_transcript_former/training/engine.py index 7319b48..3a35561 100644 --- a/src/spatial_transcript_former/training/engine.py +++ b/src/spatial_transcript_former/training/engine.py @@ -49,7 +49,6 @@ def train_one_epoch( criterion, optimizer, device, - sparsity_lambda=0.0, whole_slide=False, scaler=None, grad_accum_steps=1, @@ -103,9 +102,6 @@ def train_one_epoch( bag_target = _compute_bag_target(genes, mask) loss = criterion(preds, bag_target) - if sparsity_lambda > 0 and hasattr(model, "get_sparsity_loss"): - loss = loss + (sparsity_lambda * model.get_sparsity_loss()) - loss = loss / grad_accum_steps _optimizer_step( @@ -128,9 +124,6 @@ def train_one_epoch( loss = criterion(outputs, targets) - if sparsity_lambda > 0 and hasattr(model, 
"get_sparsity_loss"): - loss = loss + (sparsity_lambda * model.get_sparsity_loss()) - loss = loss / grad_accum_steps _optimizer_step( diff --git a/src/spatial_transcript_former/training/experiment_logger.py b/src/spatial_transcript_former/training/experiment_logger.py index 783c278..b2de28d 100644 --- a/src/spatial_transcript_former/training/experiment_logger.py +++ b/src/spatial_transcript_former/training/experiment_logger.py @@ -6,19 +6,19 @@ """ import os -import csv import json import time +import sqlite3 from datetime import datetime from typing import Any, Dict, Optional class ExperimentLogger: """ - Logs training metrics to CSV and writes a JSON summary at the end. + Logs training metrics to a SQLite database and writes a JSON summary at the end. Output files: - - training_log.csv: Per-epoch metrics (epoch, train_loss, val_loss, ...) + - training_logs.sqlite: Per-epoch metrics (epoch, train_loss, val_loss, ...) stored in a table `metrics`. - results_summary.json: Full config + final metrics """ @@ -30,15 +30,47 @@ def __init__(self, output_dir: str, config: Dict[str, Any]): """ self.output_dir = output_dir self.config = config - self.csv_path = os.path.join(output_dir, "training_log.csv") + self.db_path = os.path.join(output_dir, "training_logs.sqlite") self.json_path = os.path.join(output_dir, "results_summary.json") self.start_time = time.time() self.epoch_metrics = [] - self._csv_header_written = os.path.exists(self.csv_path) + + self._init_db() + + def _init_db(self): + """Initializes the SQLite database and metric table if it doesn't exist.""" + # Using connect as a context manager ensures commits + with sqlite3.connect(self.db_path) as conn: + # We use a dynamic schema where columns are added as needed. + # Start with just 'epoch' as the primary key. 
+ conn.execute( + """ + CREATE TABLE IF NOT EXISTS metrics ( + epoch INTEGER PRIMARY KEY + ) + """ + ) + + def _ensure_columns(self, metrics: Dict[str, float]): + """Ensures all metric keys exist as columns in the metrics table.""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("PRAGMA table_info(metrics)") + existing_columns = {col[1] for col in cursor.fetchall()} + + for key in metrics.keys(): + if key not in existing_columns: + # SQLite alters don't fail if concurrent unless locked. + # Try to add missing column as REAL (float) + try: + cursor.execute(f"ALTER TABLE metrics ADD COLUMN {key} REAL") + except sqlite3.OperationalError: + # Might have been added by another process if we are running distributed + pass def log_epoch(self, epoch: int, metrics: Dict[str, float]): """ - Append one row to training_log.csv. + Insert one row into training_logs.sqlite -> metrics table. Args: epoch: Current epoch number (1-indexed). @@ -47,15 +79,19 @@ def log_epoch(self, epoch: int, metrics: Dict[str, float]): row = {"epoch": epoch, **metrics} self.epoch_metrics.append(row) - # Determine fieldnames from first row - fieldnames = list(row.keys()) + # Ensure all columns exist before inserting + self._ensure_columns(metrics) + + columns = ", ".join(row.keys()) + placeholders = ", ".join(["?"] * len(row)) + values = tuple(row.values()) - with open(self.csv_path, "a", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - if not self._csv_header_written: - writer.writeheader() - self._csv_header_written = True - writer.writerow(row) + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute( + f"INSERT OR REPLACE INTO metrics ({columns}) VALUES ({placeholders})", + values, + ) def finalize( self, best_val_loss: float, extra_metrics: Optional[Dict[str, Any]] = None diff --git a/src/spatial_transcript_former/training/losses.py b/src/spatial_transcript_former/training/losses.py index e9a3bc2..bb5bafe 
100644 --- a/src/spatial_transcript_former/training/losses.py +++ b/src/spatial_transcript_former/training/losses.py @@ -327,10 +327,51 @@ def forward(self, gene_preds, target_genes, mask=None, pathway_preds=None): return gene_loss # Compute pathway ground truth from gene expression - # target_genes: (B, [N,] G), pathway_matrix: (P, G) - # result: (B, [N,] P) + # 1. Spatially standardize (Z-score) the target genes to ensure equal weighting with torch.no_grad(): - target_pathways = torch.matmul(target_genes, self.pathway_matrix.T) + if target_genes.dim() == 2: + # Patch level: (B, G). Normalize across the batch dimension (which acts as spatial context) + if target_genes.shape[0] > 1: + means = target_genes.mean(dim=0, keepdim=True) + stds = target_genes.std(dim=0, keepdim=True).clamp(min=1e-6) + norm_genes = (target_genes - means) / stds + else: + norm_genes = torch.zeros_like(target_genes) + else: + # Whole slide: (B, N, G). Normalize across valid spatial positions N + if mask is not None: + valid_mask = ~mask.unsqueeze(-1) # (B, N, 1) + valid_counts = valid_mask.sum(dim=1, keepdim=True).clamp( + min=1.0 + ) # (B, 1, 1) + + means = (target_genes * valid_mask.float()).sum( + dim=1, keepdim=True + ) / valid_counts + + # Compute variance explicitly to handle masking correctly + diffs = (target_genes - means) * valid_mask.float() + vars = (diffs**2).sum(dim=1, keepdim=True) / ( + valid_counts - 1 + ).clamp(min=1.0) + stds = torch.sqrt(vars).clamp(min=1e-6) + + norm_genes = diffs / stds + norm_genes = norm_genes * valid_mask.float() + else: + means = target_genes.mean(dim=1, keepdim=True) + stds = target_genes.std(dim=1, keepdim=True).clamp(min=1e-6) + norm_genes = (target_genes - means) / stds + + # 2. Project normalized genes onto the pathway matrix + # target_pathways: (B, P) or (B, N, P) + target_pathways = torch.matmul(norm_genes, self.pathway_matrix.T) + + # 3. 
Average by the number of genes in each pathway + member_counts = self.pathway_matrix.sum(dim=1, keepdim=True).T.clamp( + min=1.0 + ) + target_pathways = target_pathways / member_counts pathway_loss = self.pcc(pathway_preds, target_pathways, mask=mask) diff --git a/src/spatial_transcript_former/visualization.py b/src/spatial_transcript_former/visualization.py index 2aaef7a..98a670a 100644 --- a/src/spatial_transcript_former/visualization.py +++ b/src/spatial_transcript_former/visualization.py @@ -2,31 +2,25 @@ import torch import numpy as np import h5py -from torch.utils.data import DataLoader -from spatial_transcript_former.data.dataset import ( - load_gene_expression_matrix, - HEST_FeatureDataset, - HEST_Dataset, - load_global_genes, -) -from spatial_transcript_former.models import SpatialTranscriptFormer -from spatial_transcript_former.predict import plot_training_summary +import matplotlib.pyplot as plt +from spatial_transcript_former.data.utils import setup_dataloaders def _load_histology(h5ad_path): - """Load the downscaled histology image from an h5ad file. - - Returns: - tuple: (image_array, scale_factor) or (None, None) on failure. + """ + Load the downscaled histology image from an h5ad file. + Returns: (image_array, scale_factor) or (None, 1.0) on failure. 
""" try: + import h5py + with h5py.File(h5ad_path, "r") as f: if "uns" not in f or "spatial" not in f["uns"]: - return None, None + return None, 1.0 spatial = f["uns/spatial"] sample_key = list(spatial.keys())[0] if len(spatial.keys()) > 0 else None if sample_key is None: - return None, None + return None, 1.0 img_group = spatial[sample_key]["images"] img_key = ( "downscaled_fullres" @@ -40,26 +34,16 @@ def _load_histology(h5ad_path): if "tissue_downscaled_fullres_scalef" in scale_group else list(scale_group.keys())[0] ) - scalef = scale_group[scale_key][()] + scalef = float(scale_group[scale_key][()]) return img, scalef except Exception as e: print(f"Warning: Could not load histology: {e}") - return None, None - - -def _compute_pathway_truth(gene_truth, gene_names, args=None): - """Compute pathway ground truth from gene expression using MSigDB membership. + return None, 1.0 - For each pathway, computes the mean expression of its member genes. - This is independent of model weights and consistent across epochs. - Args: - gene_truth: (N, G) gene expression matrix (log-transformed if applicable). - gene_names: List of gene names (length G). - args: Optional CLI args to extract pathway_init filters. - - Returns: - tuple: (pathway_truth (N, P), pathway_names list) or (None, None). +def _compute_pathway_truth(gene_truth, gene_names, args): + """ + Compute pathway ground truth from gene expression using MSigDB membership. 
""" try: from spatial_transcript_former.data.pathways import ( @@ -67,280 +51,166 @@ def _compute_pathway_truth(gene_truth, gene_names, args=None): MSIGDB_URLS, ) - filter_names = None - urls = None - if args is not None and getattr(args, "pathway_init", False): - if getattr(args, "custom_gmt", None): - urls = args.custom_gmt - else: - urls = [ - MSIGDB_URLS["hallmarks"], - MSIGDB_URLS["c2_medicus"], - MSIGDB_URLS["c2_cgp"], - ] - filter_names = getattr(args, "pathways", None) - - pw_matrix, pw_names = get_pathway_init( - gene_names, gmt_urls=urls, filter_names=filter_names, verbose=False - ) - pw_np = pw_matrix.numpy() # (P, G) binary membership - member_counts = pw_np.sum(axis=1, keepdims=True).clip(min=1) # (P, 1) - # Mean expression of member genes per pathway - pathway_truth = (gene_truth @ pw_np.T) / member_counts.T # (N, P) - return pathway_truth, pw_names + urls = [MSIGDB_URLS["hallmarks"]] + # Only use hallmarks for periodic visualization to keep it fast + pw_matrix, pw_names = get_pathway_init(gene_names, gmt_urls=urls, verbose=False) + pw_np = pw_matrix.numpy() # (P, G) + + # Z-score normalize gene spatial patterns to match AuxiliaryPathwayLoss + gene_truth = gene_truth.astype(np.float64) + means = np.mean(gene_truth, axis=0, keepdims=True) + stds = np.std(gene_truth, axis=0, keepdims=True) + stds[stds < 1e-6] = 1e-6 # prevent division by zero + norm_genes = (gene_truth - means) / stds + + member_counts = pw_np.sum(axis=1, keepdims=True).clip(min=1) + # Mean expression of normalized member genes per pathway + pathway_truth = (norm_genes @ pw_np.T) / member_counts.T # (N, P) + return pathway_truth.astype(np.float32), pw_names except Exception as e: print(f"Warning: Could not compute pathway ground truth: {e}") - import traceback - - traceback.print_exc() return None, None -def _get_pathway_names(): - """Get MSigDB Hallmark pathway names. - - Returns: - list: Pathway names, or None on failure. 
- """ - try: - from spatial_transcript_former.data.pathways import ( - download_msigdb_gmt, - parse_gmt, - MSIGDB_URLS, - ) - - url = MSIGDB_URLS["hallmarks"] - filename = url.split("/")[-1] - gmt_path = download_msigdb_gmt(url, filename, ".cache") - pathway_dict = parse_gmt(gmt_path) - return list(pathway_dict.keys()) - except Exception: - return None - - def run_inference_plot(model, args, sample_id, epoch, device): """ - Run inference on a single sample and save a unified pathway visualization. - - Produces a single figure per epoch showing histology + core - pathways (ground truth vs prediction), where ground truth is computed by - projecting true gene expression through the model's gene_reconstructor - via pseudo-inverse so both live in the same activation space. - - Args: - model (nn.Module): The model to use. - args (argparse.Namespace): CLI arguments. - sample_id (str): ID of the sample to plot. - epoch (int): Current epoch (for filename). - device (torch.device): Device to run inference on. + Generates a high-quality spatial visualization of pathway predictions. 
""" - try: - plot_pathways = getattr(args, "plot_pathways", False) - if not plot_pathways: - return - - log_transform = getattr(args, "log_transform", False) - - with torch.no_grad(): - model.eval() - - # Setup paths - patches_dir = ( - os.path.join(args.data_dir, "patches") - if os.path.isdir(os.path.join(args.data_dir, "patches")) - else args.data_dir - ) - st_dir = os.path.join(args.data_dir, "st") - h5_path = os.path.join(patches_dir, f"{sample_id}.h5") - h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") - - # Load global genes - try: - common_gene_names = load_global_genes(args.data_dir, args.num_genes) - except Exception: - common_gene_names = None - - # Run inference - preds = [] - pathways_list = [] - - if args.precomputed: - feat_dir_name = ( - "he_features" - if args.backbone == "resnet50" - else f"he_features_{args.backbone}" - ) - feature_path = os.path.join( - args.data_dir, feat_dir_name, f"{sample_id}.pt" - ) - if not os.path.exists(feature_path): - feature_path = os.path.join( - args.data_dir, "patches", feat_dir_name, f"{sample_id}.pt" - ) - - ds = HEST_FeatureDataset( - feature_path, - h5ad_path, - num_genes=args.num_genes, - selected_gene_names=common_gene_names, - n_neighbors=args.n_neighbors, - whole_slide_mode=args.whole_slide, - log1p=log_transform, - ) - - if args.whole_slide: - feats, gene_targets, slide_coords = ds[0] - feats = feats.unsqueeze(0).to(device) - slide_coords = slide_coords.unsqueeze(0).to(device) - if isinstance(model, SpatialTranscriptFormer): - output = model( - feats, - return_dense=True, - rel_coords=slide_coords, - return_pathways=True, - ) - if isinstance(output, tuple): - out_preds = output[0] - if isinstance(out_preds, tuple): - out_preds = out_preds[1] - preds.append(out_preds.detach().cpu().squeeze(0)) - pathways_list.append(output[1].detach().cpu().squeeze(0)) - else: - preds.append(output.detach().cpu().squeeze(0)) - - # Use raw pixel coords from the .pt file for histology overlay - # These are guaranteed to 
be in the same order as the features - saved_data = torch.load( - feature_path, map_location="cpu", weights_only=True - ) - raw_coords = saved_data["coords"] # (N, 2) - raw_barcodes = saved_data["barcodes"] - del saved_data - - # Compute the same mask the dataset used to filter - _, pt_mask, gene_names = load_gene_expression_matrix( - h5ad_path, - raw_barcodes, - selected_gene_names=common_gene_names, - num_genes=args.num_genes, - ) - pt_mask_bool = np.array(pt_mask, dtype=bool) - coord_subset = raw_coords[pt_mask_bool].numpy() - - # Pathway truth from the dataset's aligned gene matrix - gene_truth = gene_targets.numpy() - pathway_truth, pathway_names = _compute_pathway_truth( - gene_truth, gene_names, args=args - ) - else: - dl = DataLoader(ds, batch_size=32, shuffle=False) - for feats, _, rel_coords_batch in dl: - if isinstance(model, SpatialTranscriptFormer): - output = model( - feats.to(device), - rel_coords=rel_coords_batch.to(device), - return_pathways=True, - ) - if isinstance(output, tuple): - pathways_list.append(output[1].cpu()) - out_preds = output[0] - if isinstance(out_preds, tuple): - out_preds = out_preds[1] - preds.append(out_preds.cpu()) - else: - preds.append(output.cpu()) - else: - preds.append(model(feats.to(device)).cpu()) - - # Non-whole-slide: use h5 file coords (same source as DataLoader) - with h5py.File(h5_path, "r") as f: - patch_barcodes = f["barcode"][:].flatten() - h5_coords = f["coords"][:] - gene_matrix, mask, gene_names = load_gene_expression_matrix( - h5ad_path, - patch_barcodes, - selected_gene_names=common_gene_names, - num_genes=args.num_genes, - ) - coord_mask = np.array(mask, dtype=bool) - coord_subset = h5_coords[coord_mask] - gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix - pathway_truth, pathway_names = _compute_pathway_truth( - gene_truth, gene_names, args=args - ) + from spatial_transcript_former.predict import plot_training_summary + + # 1. 
Setup Data + _, val_loader = setup_dataloaders(args, [], [sample_id]) + if val_loader is None: + return + + model.eval() + preds_list = [] + pathways_list = [] + targets_list = [] + coords_list = [] + masks_list = [] + + # 2. Run Inference + with torch.no_grad(): + for batch in val_loader: + if args.whole_slide: + image_features, target, coords, mask = batch + image_features = image_features.to(device) + coords = coords.to(device) + mask = mask.to(device) + target = target.to(device) else: - with h5py.File(h5_path, "r") as f: - patch_barcodes = f["barcode"][:].flatten() - h5_coords = f["coords"][:] - gene_matrix, mask, gene_names = load_gene_expression_matrix( - h5ad_path, - patch_barcodes, - selected_gene_names=common_gene_names, - num_genes=args.num_genes, - ) - coord_mask = np.array(mask, dtype=bool) - coord_subset = h5_coords[coord_mask] - ds = HEST_Dataset( - h5_path, coord_subset, gene_matrix, indices=np.where(mask)[0] - ) - dl = DataLoader(ds, batch_size=32, shuffle=False) - for imgs, _, rel_coords_batch in dl: - if isinstance(model, SpatialTranscriptFormer): - output = model( - imgs.to(device), - rel_coords=rel_coords_batch.to(device), - return_pathways=True, - ) - if isinstance(output, tuple): - pathways_list.append(output[1].cpu()) - out_preds = output[0] - if isinstance(out_preds, tuple): - out_preds = out_preds[1] - preds.append(out_preds.cpu()) - else: - preds.append(output.cpu()) - else: - preds.append(model(imgs.to(device)).cpu()) - gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix - pathway_truth, pathway_names = _compute_pathway_truth( - gene_truth, gene_names, args=args + image_features, target, coords = batch + image_features = image_features.to(device) + coords = coords.to(device) + mask = torch.ones(target.shape[0], target.shape[1], device=device) + target = target.to(device) + + # Forward pass + if args.whole_slide: + outputs = model( + image_features, rel_coords=coords, mask=mask, return_dense=True ) + else: + outputs = 
model(image_features, rel_coords=coords) - if not preds: - print("Warning: No predictions generated. Skipping plot.") - return - - # Compute pathway activations from gene predictions (same method as truth) - # Both truth and pred are now: mean gene expression of pathway members - gene_preds_np = torch.cat(preds, dim=0).numpy() - pathway_pred, _ = _compute_pathway_truth( - gene_preds_np, gene_names, args=args - ) - - if pathway_truth is None or pathway_pred is None: - print("Warning: Could not compute pathway truth/pred. Skipping plot.") - return - - # Load histology image - histology_img, scalef = _load_histology(h5ad_path) - - # Generate unified figure - save_path = os.path.join( - args.output_dir, f"{sample_id}_epoch_{epoch+1}.png" - ) - plot_training_summary( - coord_subset, - pathway_pred, - pathway_truth, - pathway_names, - sample_id=sample_id, - histology_img=histology_img, - scalef=scalef, - save_path=save_path, - ) - - except Exception as e: - print(f"Warning: Failed to generate validation plot: {e}") - import traceback - - traceback.print_exc() + # The model might return a tuple if pathways are enabled + if isinstance(outputs, tuple): + pred_counts = outputs[0] + pred_pathways = outputs[1] if len(outputs) > 1 else None + else: + pred_counts = outputs + pred_pathways = None + + preds_list.append(pred_counts.cpu()) + if pred_pathways is not None: + pathways_list.append(pred_pathways.cpu()) + targets_list.append(target.cpu()) + coords_list.append(coords.cpu()) + masks_list.append(mask.cpu()) + + if args.whole_slide: + break # Whole slide is one batch + + # Concatenate results (for patch-based) + all_preds = torch.cat(preds_list, dim=1 if args.whole_slide else 0) + all_targets = torch.cat(targets_list, dim=1 if args.whole_slide else 0) + all_coords = torch.cat(coords_list, dim=1 if args.whole_slide else 0) + all_masks = torch.cat(masks_list, dim=1 if args.whole_slide else 0) + + if pathways_list: + all_pathways = torch.cat(pathways_list, dim=1 if 
args.whole_slide else 0) + else: + all_pathways = None + + # Squeeze batch dim for processing + pred_counts = all_preds.numpy()[0] + target_genes = all_targets.numpy()[0] + coords = all_coords.numpy()[0] + mask = all_masks.numpy()[0] + + if all_pathways is not None: + pathway_preds = all_pathways.numpy()[0] + else: + pathway_preds = None + + # Un-log if necessary to get absolute counts + if getattr(args, "log_transform", False): + pred_counts = np.expm1(pred_counts) + target_genes = np.expm1(target_genes) + if pathway_preds is not None: + pathway_preds = np.expm1(pathway_preds) + + # 3. Filter Valid Spots + if args.whole_slide: + valid_idx = ~mask.astype(bool) + else: + valid_idx = mask.astype(bool) + + coords = coords[valid_idx] + pred_counts = pred_counts[valid_idx] + target_genes = target_genes[valid_idx] + if pathway_preds is not None: + pathway_preds = pathway_preds[valid_idx] + + if len(coords) == 0: + return + + # 4. Compute Pathway Truth + from spatial_transcript_former.data.dataset import load_global_genes + + gene_names = load_global_genes(args.data_dir, args.num_genes) + + # Pathway truth calculation + pathway_truth, pathway_names = _compute_pathway_truth( + target_genes, gene_names, args + ) + + if pathway_truth is None: + print("Warning: Could not compute pathway truth. Visualization skipped.") + return + + # If the model didn't return pathways directly, use predicted genes to compute them + if pathway_preds is None: + pathway_preds, _ = _compute_pathway_truth(pred_counts, gene_names, args) + + # 5. Load Histology + st_dir = os.path.join(args.data_dir, "st") + h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") + histology_img, scalef = _load_histology(h5ad_path) + + # 6. 
Plot + os.makedirs(args.output_dir, exist_ok=True) + save_path = os.path.join(args.output_dir, f"{sample_id}_epoch_{epoch}.png") + + plot_training_summary( + coords, + pathway_preds, + pathway_truth, + pathway_names, + sample_id=sample_id, + histology_img=histology_img, + scalef=scalef, + save_path=save_path, + plot_pathways_list=getattr(args, "plot_pathways_list", None), + ) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 08fe63b..a9eb240 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -58,6 +58,7 @@ def test_save_load_preserves_weights(self, small_model, checkpoint_dir): small_model, optimizer, None, + None, # schedulers epoch=42, best_val_loss=0.123, output_dir=checkpoint_dir, @@ -74,8 +75,14 @@ def test_save_load_preserves_weights(self, small_model, checkpoint_dir): ) fresh_optimizer = optim.Adam(fresh_model.parameters(), lr=1e-4) - start_epoch, best_val = load_checkpoint( - fresh_model, fresh_optimizer, None, checkpoint_dir, "interaction", "cpu" + start_epoch, best_val, loaded_schedulers = load_checkpoint( + fresh_model, + fresh_optimizer, + None, + None, + checkpoint_dir, + "interaction", + "cpu", ) # Verify metadata @@ -98,6 +105,7 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): small_model, optimizer, scaler, + None, # schedulers epoch=10, best_val_loss=0.5, output_dir=checkpoint_dir, @@ -118,6 +126,7 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): fresh_model, fresh_optimizer, fresh_scaler, + None, # schedulers checkpoint_dir, "interaction", "cpu", @@ -129,8 +138,8 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): def test_no_checkpoint_starts_fresh(self, small_model, checkpoint_dir): """Missing checkpoint should return epoch 0 and inf loss.""" optimizer = optim.Adam(small_model.parameters(), lr=1e-4) - start_epoch, best_val = load_checkpoint( - small_model, optimizer, None, checkpoint_dir, "nonexistent", "cpu" + start_epoch, 
best_val, loaded_schedulers = load_checkpoint( + small_model, optimizer, None, None, checkpoint_dir, "nonexistent", "cpu" ) assert start_epoch == 0 assert best_val == float("inf") diff --git a/tests/test_checkpoints.py b/tests/test_checkpoints.py index 028ec10..f454c04 100644 --- a/tests/test_checkpoints.py +++ b/tests/test_checkpoints.py @@ -27,7 +27,19 @@ def test_model_structure_consistency(): assert model.gene_reconstructor.weight.shape == (num_genes, num_pathways) # Verify values match (within tolerance) - assert torch.allclose(model.gene_reconstructor.weight, pathway_init.T) + # The interaction model now L1-normalizes the pathways for stability + # shape of pathway_init is (num_pathways, num_genes) + import torch.nn.functional as F + + # We must normalize the columns of pathway_init.T, which correspond to the rows of pathway_init + # Adding a small epsilon as done in interaction.py + normalized_pathway_init = pathway_init / ( + pathway_init.sum(dim=1, keepdim=True) + 1e-6 + ) + + assert torch.allclose( + model.gene_reconstructor.weight, normalized_pathway_init.T, atol=1e-5 + ) def test_checkpoint_save_load(): diff --git a/tests/test_losses.py b/tests/test_losses.py index e5cfd34..d22d989 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -447,9 +447,34 @@ def test_perfect_match_zero_aux(self, pathway_tensors): base = MaskedMSELoss() aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=1.0) - # Compute ground truth pathways + # Compute ground truth pathways matching the new AuxiliaryPathwayLoss logic with torch.no_grad(): - target_pathways = torch.matmul(targets, pw_matrix.T) + if targets.dim() == 2: + # Patch level: (B, G). Normalize across the batch dimension + means = targets.mean(dim=0, keepdim=True) + stds = targets.std(dim=0, keepdim=True).clamp(min=1e-6) + norm_genes = (targets - means) / stds + else: + # Whole slide: (B, N, G). 
Normalize across valid spatial positions N + valid_mask = ( + ~mask.unsqueeze(-1) + if mask is not None + else torch.ones_like(targets, dtype=torch.bool) + ) + valid_counts = valid_mask.sum(dim=1, keepdim=True).clamp(min=1.0) + means = (targets * valid_mask.float()).sum( + dim=1, keepdim=True + ) / valid_counts + diffs = (targets - means) * valid_mask.float() + vars = (diffs**2).sum(dim=1, keepdim=True) / (valid_counts - 1).clamp( + min=1.0 + ) + stds = torch.sqrt(vars).clamp(min=1e-6) + norm_genes = (diffs / stds) * valid_mask.float() + + target_pathways = torch.matmul(norm_genes, pw_matrix.T) + member_counts = pw_matrix.sum(dim=1, keepdim=True).T.clamp(min=1.0) + target_pathways = target_pathways / member_counts gene_loss = base(gene_preds, targets, mask=mask) # Use target_pathways as pathway_preds @@ -566,8 +591,15 @@ def test_hallmark_signal_detection(self): loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=1.0) loss_random = loss_fn(gene_preds, targets, pathway_preds=pw_preds_random) - # Case 2: Pathway preds perfectly match truth (which is targets @ matrix.T) - pw_truth = torch.matmul(targets, pw_matrix.T) + # Case 2: Pathway preds perfectly match truth + with torch.no_grad(): + means = targets.mean(dim=1, keepdim=True) + stds = targets.std(dim=1, keepdim=True).clamp(min=1e-6) + norm_genes = (targets - means) / stds + pw_truth = torch.matmul(norm_genes, pw_matrix.T) + member_counts = pw_matrix.sum(dim=1, keepdim=True).T.clamp(min=1.0) + pw_truth = pw_truth / member_counts + loss_perfect = loss_fn(gene_preds, targets, pathway_preds=pw_truth) # Case 3: Gene expression is specifically high for P0, and pw_preds are high for P0 diff --git a/tests/test_pathway_stability.py b/tests/test_pathway_stability.py new file mode 100644 index 0000000..b6cbc9a --- /dev/null +++ b/tests/test_pathway_stability.py @@ -0,0 +1,86 @@ +import pytest +import torch +from spatial_transcript_former.models.interaction import SpatialTranscriptFormer +from 
spatial_transcript_former.training.losses import ( + AuxiliaryPathwayLoss, + MaskedMSELoss, +) + + +def test_pathway_initialization_stability_and_gradients(): + """ + Verifies that initializing the model with a binary pathway matrix: + 1. Does not cause predictions to exponentially explode (numerical stability). + 2. Allows gradients to flow properly when using AuxiliaryPathwayLoss. + """ + torch.manual_seed(42) + num_pathways = 50 + num_genes = 100 + + # Create a synthetic MSigDB-style binary matrix + pathway_matrix = (torch.rand(num_pathways, num_genes) > 0.8).float() + # Ensure no empty pathways to avoid division by zero + pathway_matrix[:, 0] = 1.0 + + # Initialize model with pathway_init + model = SpatialTranscriptFormer( + num_genes=num_genes, + num_pathways=num_pathways, + pathway_init=pathway_matrix, + use_spatial_pe=False, + output_mode="counts", + pretrained=False, + ) + + # Dummy inputs + B, S, D = ( + 2, + 10, + 2048, + ) # D=2048: the feature size backbone='resnet50' expects natively, also used when features are provided directly + feats = torch.randn(B, S, D, requires_grad=True) + coords = torch.randn(B, S, 2) + target_genes = torch.randn(B, S, num_genes).abs() + mask = torch.zeros(B, S, dtype=torch.bool) + + # Forward pass + # return_pathways=True is needed to get the intermediate pathway preds for AuxiliaryPathwayLoss + gene_preds, pathway_preds = model( + feats, rel_coords=coords, return_dense=True, return_pathways=True + ) + + # 1. Numerical Stability Check + # Without L1 normalization and the removal of temperature scaling, predictions would explode. + # With the fix, Softplus should keep outputs reasonably small. + max_pred = gene_preds.max().item() + print(f"Max prediction value at initialization: {max_pred:.2f}") + assert ( + max_pred < 100.0 + ), f"Predictions exploded! Max value: {max_pred}. Check L1 normalization." + assert not torch.isnan(gene_preds).any(), "Found NaNs in initial predictions." + + # 2.
Gradient Flow Check (Compatibility with Training) + loss_fn = AuxiliaryPathwayLoss(pathway_matrix, MaskedMSELoss(), lambda_pathway=1.0) + loss = loss_fn(gene_preds, target_genes, mask=mask, pathway_preds=pathway_preds) + + assert loss.isfinite(), "Loss is not finite." + + loss.backward() + + # Verify gradients reached the core transformer layers + target_layer_grad = model.fusion_engine.layers[0].linear1.weight.grad + assert target_layer_grad is not None, "Gradients did not reach the fusion engine." + assert target_layer_grad.norm() > 0, "Vanishing gradients in the fusion engine." + assert torch.isfinite( + target_layer_grad + ).all(), "Exploding/NaN gradients in fusion engine." + + # Verify gradients reached the final reconstructor layer + recon_grad = model.gene_reconstructor.weight.grad + assert recon_grad is not None, "Gradients did not reach the gene reconstructor." + assert recon_grad.norm() > 0, "Vanishing gradients in the gene reconstructor." + assert torch.isfinite( + recon_grad + ).all(), "Exploding/NaN gradients in gene reconstructor." 
+ + print("Pathway initialization is fully stable and compatible with NN training.") diff --git a/tests/test_pathways.py b/tests/test_pathways.py index 958498e..f732fa2 100644 --- a/tests/test_pathways.py +++ b/tests/test_pathways.py @@ -138,12 +138,16 @@ class TestPathwayTruth: def test_consistent_across_calls(self, gene_list): """Ground truth from MSigDB membership should be identical across calls.""" from spatial_transcript_former.visualization import _compute_pathway_truth + from unittest.mock import MagicMock + + args = MagicMock() + args.sparsity_lambda = 0.0 np.random.seed(42) gene_truth = np.random.rand(200, len(gene_list)).astype(np.float32) - result1, names1 = _compute_pathway_truth(gene_truth, gene_list) - result2, names2 = _compute_pathway_truth(gene_truth, gene_list) + result1, names1 = _compute_pathway_truth(gene_truth, gene_list, args) + result2, names2 = _compute_pathway_truth(gene_truth, gene_list, args) np.testing.assert_array_equal(result1, result2) assert names1 == names2 @@ -151,10 +155,14 @@ def test_consistent_across_calls(self, gene_list): def test_output_shape(self, gene_list): """Pathway truth should be (N, P) where P=50 (Hallmarks default).""" from spatial_transcript_former.visualization import _compute_pathway_truth + from unittest.mock import MagicMock + + args = MagicMock() + args.sparsity_lambda = 0.0 N = 150 gene_truth = np.random.rand(N, len(gene_list)).astype(np.float32) - result, names = _compute_pathway_truth(gene_truth, gene_list) + result, names = _compute_pathway_truth(gene_truth, gene_list, args) assert result.shape == (N, 50) assert len(names) == 50 @@ -162,6 +170,10 @@ def test_output_shape(self, gene_list): def test_spatial_variation(self, gene_list): """Pathway truth should have spatial variation (non-zero std).""" from spatial_transcript_former.visualization import _compute_pathway_truth + from unittest.mock import MagicMock + + args = MagicMock() + args.sparsity_lambda = 0.0 # Create gene expression with spatial 
patterns N = 200 @@ -170,7 +182,7 @@ def test_spatial_variation(self, gene_list): gene_truth[:100, 0] += 5.0 gene_truth[100:, 1] += 5.0 - result, _ = _compute_pathway_truth(gene_truth, gene_list) + result, _ = _compute_pathway_truth(gene_truth, gene_list, args) # At least some pathways should have non-trivial spatial variation stds = np.std(result, axis=0) diff --git a/tests/test_spatial_interaction.py b/tests/test_spatial_interaction.py index b084ec4..d15a210 100644 --- a/tests/test_spatial_interaction.py +++ b/tests/test_spatial_interaction.py @@ -250,26 +250,6 @@ def test_interaction_mask_bits(): assert mask[2, 3] == False, "h2h interaction [2, 3] should be enabled" -def test_temperature_scaling(): - """Verify log_temperature actually scales the pathway scores.""" - model = SpatialTranscriptFormer(num_genes=10, token_dim=64) - features = torch.randn(1, 4, 2048) - coords = torch.randn(1, 4, 2) - - # Initial scores with default temp - scores1 = model(features, rel_coords=coords, return_pathways=True)[1] - - # Manually increase log_temperature significantly - with torch.no_grad(): - model.log_temperature.fill_(10.0) # Massive temp - - scores2 = model(features, rel_coords=coords, return_pathways=True)[1] - - # Scores should be different and typically more extreme - assert not torch.allclose(scores1, scores2) - assert scores2.abs().max() > scores1.abs().max() - - def test_return_attention_values(): """Validate attention weight extraction logic.""" model = SpatialTranscriptFormer( diff --git a/tests/test_visualization.py b/tests/test_visualization.py index c149232..4da4385 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -12,10 +12,19 @@ matplotlib.use("Agg") from spatial_transcript_former.predict import ( - CORE_PATHWAYS, plot_training_summary, ) +# Representative pathways to highlight in summary plots (local copy for testing) +CORE_PATHWAYS = [ + "APOPTOSIS", + "DNA_REPAIR", + "G2M_CHECKPOINT", + "MTORC1_SIGNALING", + "P53_PATHWAY", + 
"MYC_TARGETS_V1", +] + # --------------------------------------------------------------------------- # Fixtures # ---------------------------------------------------------------------------