From af004f812121619e43f9c580b069bdf17ff233f2 Mon Sep 17 00:00:00 2001 From: BenjaminIsaac0111 <12176376+BenjaminIsaac0111@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:53:56 +0100 Subject: [PATCH] feat: add CCC, Huber, and CLIP losses with per-pathway validation metrics and SVG analysis - Implement CCCLoss, MaskedHuberLoss, and CLIPAlignmentLoss in training losses. - Update CompositeLoss to support configurable combinations of MSE/Huber, PCC/CCC, and optional CLIP regularisation for future work. - Enhance validation engine to compute, track, and log both PCC and CCC metrics, including slide-level and per-pathway breakdowns. - Add exploratory SVG (Spatially Variable Genes) analysis script and documentation. - Update argument parsing, presets, pathways computation, checkpointing, and visualization. - Implement comprehensive unit tests for new loss functions, pathway evaluations, and visualization utilities. --- README.md | 5 +- docs/PATHWAY_MAPPING.md | 21 +- docs/SVG_HEST_EXPLORATORY_ANALYSIS.md | 71 ++ scripts/analyze_svg.py | 611 ++++++++++++++++++ scripts/run_preset.py | 75 ++- src/spatial_transcript_former/checkpoint.py | 23 +- .../dashboard/callbacks.py | 10 +- .../models/interaction.py | 19 +- src/spatial_transcript_former/predict.py | 22 +- .../hest/compute_pathway_activities.py | 55 +- .../recipes/hest/dataset.py | 78 +-- src/spatial_transcript_former/train.py | 94 ++- .../training/arguments.py | 34 +- .../training/builder.py | 29 +- .../training/checkpoint.py | 46 +- .../training/engine.py | 252 +++++--- .../training/losses.py | 244 +++++-- .../training/trainer.py | 33 +- .../visualization.py | 74 +-- tests/data/test_pathways.py | 127 +++- tests/data/test_visualization.py | 96 +++ tests/models/test_interactions.py | 38 +- tests/recipes/hest/test_dataset.py | 4 +- tests/training/test_checkpoints.py | 12 +- tests/training/test_losses.py | 241 +++++-- 25 files changed, 1896 insertions(+), 418 deletions(-) create mode 100644 
docs/SVG_HEST_EXPLORATORY_ANALYSIS.md create mode 100644 scripts/analyze_svg.py diff --git a/README.md b/README.md index 01d8fb2..9d9dc7c 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ For more details, see the **[Python API Reference](docs/API.md)**. - **Modular Architecture**: Decoupled backbones, interaction modules, and pathway output heads. - **Quad-Flow Interaction**: Configurable attention between Pathways and Histology patches (`p2p`, `p2h`, `h2p`, `h2h`). - **Pathway-Exclusive Prediction**: Directly predicts biological pathway activity scores (e.g., 50 MSigDB Hallmark pathways) — no intermediate gene reconstruction step. -- **Offline Pathway Targets**: Ground-truth pathway activities are pre-computed offline (`stf-compute-pathways`) from raw gene expression using QC → CP10k normalisation → z-score → mean pathway aggregation. This eliminates the circular auxiliary loss used in previous versions. +- **Offline Pathway Targets**: Ground-truth pathway activities are pre-computed offline (`stf-compute-pathways`) from raw gene expression using QC → CP10k normalisation → mean pathway aggregation. This eliminates the circular auxiliary loss used in previous versions. - **Spatial Pattern Coherence**: Optimised using a composite **MSE + PCC (Pearson Correlation) loss**. - **Foundation Model Ready**: Native support for **CTransPath**, **Phikon**, **Hibou**, **PLIP**, and **GigaPath**. @@ -87,7 +87,7 @@ stf-download --organ Breast --disease Cancer --tech Visium --local_dir hest_data ### 2. Pre-Compute Pathway Activity Targets -Before training, compute the offline pathway activity matrix for each sample. This step applies per-spot QC, CP10k normalisation, and z-scoring before aggregating gene expression into MSigDB Hallmark pathway scores. +Before training, compute the offline pathway activity matrix for each sample. 
This step applies per-spot QC and CP10k normalisation, then aggregates gene expression into MSigDB Hallmark pathway scores as the per-spot mean over each pathway's member genes. ```bash stf-compute-pathways --data-dir hest_data @@ -123,6 +123,7 @@ Visualization plots and spatial pathway activation maps will be saved to the `./ - **[Models & Architecture](docs/MODELS.md)**: Deep dive into the pathway-exclusive prediction architecture, quad-flow interaction logic, and network scaling. - **[Pathway Mapping](docs/PATHWAY_MAPPING.md)**: Offline pathway scoring methodology, QC pipeline, and MSigDB integration. +- **[SVG Exploratory Analysis](docs/SVG_HEST_EXPLORATORY_ANALYSIS.md)**: Detailed report on spatially variable pathway analysis across 95 HEST samples and data-driven target curation. - **[Data Structure](docs/DATA_FORMAT.md)**: Detailed breakdown of the HEST data structure on disk, metadata conventions, and preprocessing invariants. ## Development diff --git a/docs/PATHWAY_MAPPING.md b/docs/PATHWAY_MAPPING.md index 074cdeb..0c40cc7 100644 --- a/docs/PATHWAY_MAPPING.md +++ b/docs/PATHWAY_MAPPING.md @@ -16,9 +16,8 @@ For each `.h5ad` file, the following steps are applied in order: | :--- | :--- | :--- | | **1. QC Filtering** | Remove low-quality spots (min UMIs, min detected genes, max MT%) on **raw counts** | QC before normalisation prevents low-quality spots from distorting library-size estimates | | **2. CP10k Normalisation** | Scale each spot to 10,000 total counts, then apply `log1p` | Corrects for sequencing depth differences between spots | -| **3. Gene Z-Scoring** | Standardise each gene across surviving spots (mean=0, std=1) | Eliminates housekeeping gene dominance; every gene gets equal weight | -| **4. Pathway Aggregation** | For each pathway: take the mean z-score of its member genes present in the matrix | Produces a single, comparable activity score per pathway per spot | -| **5. 
Moran I** | Compute Moran's I for each gene on the raw counts | Computes spatial autocorrelation for each gene | +| **3. Pathway Aggregation** | For each pathway: take the mean log1p CP10k expression of its member genes present in the matrix | Slide-stationary by construction — no per-slide statistics enter the score, so the same biological state in two slides yields the same target value | +| **4. Moran's I (diagnostic)** | Compute per-pathway Moran's I on the activity matrix | Records spatial autocorrelation alongside the targets; not used in the loss | Pathways with fewer than `--min-genes` (default: 5) detected members are filled with zeros. Samples with fewer than `--min-pathways` (default: 25) scorable pathways are excluded entirely. @@ -38,10 +37,12 @@ These defaults follow standard scRNA-seq / spatial transcriptomics QC practice t Each sample is saved as a compressed HDF5 file at `/pathway_activities/.h5`: ```text -activities float32 (n_spots, n_pathways) # z-scored pathway activity matrix -barcodes bytes (n_spots,) # spot barcode strings -pathway_names bytes (n_pathways,) # pathway name labels +activities float32 (n_spots, n_pathways) # mean log1p CP10k pathway score +barcodes bytes (n_spots,) # spot barcode strings +pathway_names bytes (n_pathways,) # pathway name labels +pathway_morans_i float32 (n_pathways,) # per-pathway Moran's I (diagnostic) attrs: + format_version int # on-disk schema version (current: 2) n_spots_before_qc int # total spots in raw h5ad n_spots_after_qc int # spots surviving QC qc_min_umis int @@ -50,6 +51,12 @@ attrs: n_scored_pathways int # pathways meeting the min_genes threshold ``` +> **Breaking change (v2):** `activities` is now the simple mean of log1p +> CP10k expression of pathway members — no per-slide z-score. Files written +> by older builds carry `format_version=1` (or no version attribute) and are +> rejected at load time. Re-run `stf-compute-pathways --overwrite` to +> regenerate. 
+ These files are consumed at training time by `HEST_FeatureDataset` when `--pathway-targets-dir` is provided (which defaults to `/pathway_activities`). ### Usage @@ -100,7 +107,7 @@ The current design eliminates this entirely: | Aspect | Old (Auxiliary Loss) | New (Pre-computed Targets) | | :--- | :--- | :--- | | Target source | Computed in-flight from training labels | Computed once, offline, from raw expression | -| QC & normalisation | None | Per-spot QC → CP10k → z-score | +| QC & normalisation | None | Per-spot QC → CP10k → mean pathway aggregation | | Model output | Gene expression (via gene reconstructor) | Pathway activity scores directly | | Loss objective | `L_gene + λ · (1 - PCC(scores, pseudo-targets))` | `MSE + PCC` against pre-computed activities | | Interpretability | Indirect (pathway scores were internal and needed to be mapped back to pathways) | Direct (output *is* the pathway activity) | diff --git a/docs/SVG_HEST_EXPLORATORY_ANALYSIS.md b/docs/SVG_HEST_EXPLORATORY_ANALYSIS.md new file mode 100644 index 0000000..2b86a59 --- /dev/null +++ b/docs/SVG_HEST_EXPLORATORY_ANALYSIS.md @@ -0,0 +1,71 @@ +# SVG Exploratory Analysis & Pathway Curation Report + +This report summarizes the data-driven analysis of Spatially Variable Pathways across the HEST dataset (95 samples, 85 valid human samples after QC) conducted to refine the targets for model training. + +## 1. Methodology + +The analysis was performed using a standalone utility (`scripts/analyze_svg.py`) that: +1. **Strips common gene prefixes** (e.g., `GRCh38_`, `GRCm38_`) to ensure compatibility with MSigDB Hallmark gene sets. +2. **Computes pathway activities** with the same scorer as `stf-compute-pathways`: CP10k normalisation (10k target sum) with `log1p`, then the mean over each pathway's member genes. +3. **Calculates Moran's I** for each of the 50 Hallmark pathways per sample. +4. **Aggregates statistics** (mean, median, std, etc.) across all valid human samples. +5. **Analyzes correlations** between spot-level pathway activities to understand redundancy. + +## 2. 
Global Spatial Autocorrelation Results + +The following plot shows the ranked mean Moran's I across 85 human samples for all 50 Hallmark pathways. + +![Global SVG Analysis](./assets/reports/svg_analysis_full.png) + +### Key Observations: +* **Widespread Spatial Structure**: All 50 pathways exhibit positive spatial autocorrelation (Mean Moran's I > 0.15). +* **High-Signal Pathways**: Top-ranked pathways include **MYC Targets V1** (0.665), **E2F Targets** (0.639), **G2M Checkpoint** (0.633), and **Oxidative Phosphorylation** (0.631). +* **Variance vs. Spatiality**: High expression variance does not always equate to high spatial coherence. Some pathways vary significantly between spots but lack a spatially organized pattern. + +--- + +## 3. CRC Pathway Curation + +Based on these results, the curated list of pathways for Colorectal Cancer (CRC) was validated. While some pathways exhibit lower spatial autocorrelation than others, all 14 selected hallmarks exceed a significance baseline of **Mean Moran's I > 0.20** and are therefore retained for training. 
+ +| Status | Pathway | Mean Moran's I | % Samples > 0.05 | +| :--- | :--- | :--- | :--- | +| ✅ **Retained** | EPITHELIAL_MESENCHYMAL_TRANSITION | 0.602 | 98.8% | +| ✅ **Retained** | DNA_REPAIR | 0.554 | 91.8% | +| ✅ **Retained** | APOPTOSIS | 0.547 | 100.0% | +| ✅ **Retained** | P53_PATHWAY | 0.546 | 92.9% | +| ✅ **Retained** | HYPOXIA | 0.539 | 100.0% | +| ✅ **Retained** | APICAL_JUNCTION | 0.498 | 100.0% | +| ✅ **Retained** | INFLAMMATORY_RESPONSE | 0.487 | 100.0% | +| ✅ **Retained** | PI3K_AKT_MTOR_SIGNALING | 0.483 | 91.8% | +| ✅ **Retained** | KRAS_SIGNALING_UP | 0.469 | 98.8% | +| ✅ **Retained** | IL6_JAK_STAT3_SIGNALING | 0.408 | 98.8% | +| ✅ **Retained** | TGF_BETA_SIGNALING | 0.397 | 94.1% | +| ✅ **Retained** | ANGIOGENESIS | 0.339 | 94.1% | +| ✅ **Retained** | WNT_BETA_CATENIN_SIGNALING | 0.302 | 90.6% | +| ✅ **Retained** | KRAS_SIGNALING_DN | 0.250 | 95.3% | + +### Rationale: +Although pathways like **WNT/β-catenin** and **KRAS_DN** have lower Moran's I scores (0.30 and 0.25 respectively) compared to **EMT** (0.60), they still exceed the 0.20 retention baseline and sit above the ~0.15 floor observed across all 50 Hallmark pathways. Their relative spatial uniformity likely reflects constitutive activation by driver mutations (e.g., APC mutations making WNT "on" globally), but the remaining spatial gradients are biologically critical for capturing tumor margins and stroma-epithelial interactions. + +--- + +## 4. Pathway Correlation & Redundancy + +To ensure the model is learning distinct biological signals, we analyzed the correlation between spot-level activities of the 14 CRC pathways. + +![Pathway Correlations](./assets/reports/pathway_correlations_full.png) + +### Correlation Insights: +* **Biological Axes**: Strong correlations exist between **Angiogenesis** and **EMT** (r=0.749), and between **TGF-β** and **Apoptosis** (r=0.668). These axes represent co-regulated spatial processes. 
+* **Distinct Signals**: Despite these correlations, each pathway provides a unique biological "view" of the tissue. Retaining the full set allows the model to learn complex regulatory relationships rather than just isolated spatial patterns. + +**Conclusion**: All 14 CRC pathways exhibit sufficient spatial structure and biological relevance to be included as training targets. This ensures the model learns a comprehensive representation of the CRC tissue microenvironment. + +--- + +## 5. Technical Improvements + +During this analysis, two critical fixes were implemented: +1. **Gene Prefix Stripping**: Fixed an issue where samples like `TENX175` had all-zero pathway scores because gene names were prefixed with `GRCh38_`. +2. **Sample Compatibility Check**: Added a check for Hallmark gene overlap to automatically skip mouse samples or low-density panels that cannot be accurately scored using human Hallmark sets. diff --git a/scripts/analyze_svg.py b/scripts/analyze_svg.py new file mode 100644 index 0000000..b228407 --- /dev/null +++ b/scripts/analyze_svg.py @@ -0,0 +1,611 @@ +""" +Exploratory analysis of Spatially Variable Genes/Pathways (SVGs). + +Reads .h5ad files from the HEST data directory, computes pathway activity +scores and Moran's I spatial autocorrelation, then produces a summary +report showing which pathways exhibit genuine spatial structure vs noise. + +This is a pure analysis tool — it does not modify any training data or +model configuration. Use the output to make informed decisions about +which pathways to include as training targets. 
+ +Usage:: + + python scripts/analyze_svg.py --data-dir hest_data + python scripts/analyze_svg.py --data-dir hest_data --sample-ids TENX175 + python scripts/analyze_svg.py --data-dir hest_data --threshold 0.05 + python scripts/analyze_svg.py --data-dir hest_data --plot +""" + +import argparse +import os +import sys +import logging + +import numpy as np +import pandas as pd + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + _load_expression, + _load_hallmark_sets, + _score_pathways, + _compute_pathway_morans_i, +) + +logger = logging.getLogger(__name__) + + +def analyze_sample(h5ad_path, pathway_dict, min_genes=5, k=6): + """Compute pathway activities and Moran's I for a single sample. + + Returns + ------- + sample_id : str + pathway_names : list of str + morans : np.ndarray, shape (n_pathways,) + n_spots : int + n_scored : int + """ + sample_id = os.path.basename(h5ad_path).replace(".h5ad", "") + + adata, n_before, n_after = _load_expression( + h5ad_path, + target_sum=10_000, + qc_min_umis=500, + qc_min_genes=200, + qc_max_mt=0.15, + ) + + if n_after == 0: + logger.warning(f"[{sample_id}] All spots filtered by QC. 
Skipping.") + return None + + from scipy.sparse import issparse + + expr = adata.X + if issparse(expr): + expr = expr.toarray() + expr = expr.astype(np.float32) + gene_names = list(adata.var_names) + + # Strip common prefixes (GRCh38_, GRCm38_) often found in HEST/10x data + stripped_names = [] + for g in gene_names: + if "_" in g and (g.startswith("GRCh38_") or g.startswith("GRCm38_")): + stripped_names.append(g.split("_", 1)[1]) + else: + stripped_names.append(g) + gene_names = stripped_names + + # Quick check for Hallmark coverage + all_hallmark_genes = set() + for genes in pathway_dict.values(): + all_hallmark_genes.update(genes) + + overlap = set(gene_names) & all_hallmark_genes + if len(overlap) < 100: + logger.warning( + f"[{sample_id}] Insufficient Hallmark gene overlap ({len(overlap)}). Skipping (likely non-human or low-density panel)." + ) + return None + + activities, all_pathways, n_scored = _score_pathways( + expr, gene_names, pathway_dict, min_genes=min_genes + ) + + # Get spatial coordinates + coords = ( + np.column_stack( + [adata.obs["array_row"].values, adata.obs["array_col"].values] + ).astype(np.float64) + if "array_row" in adata.obs.columns + else None + ) + if coords is None: + for key in ["spatial", "X_spatial"]: + if key in adata.obsm: + coords = np.array(adata.obsm[key], dtype=np.float64)[:, :2] + break + + if coords is None or len(coords) < k + 1: + logger.warning(f"[{sample_id}] No spatial coordinates. 
Skipping.") + return None + + morans = _compute_pathway_morans_i(activities, coords, k=k) + + # Also compute per-pathway variance (to detect zero-variance pathways) + variances = activities.var(axis=0) + + return { + "sample_id": sample_id, + "pathway_names": all_pathways, + "activities": activities, # (n_spots, n_pathways) + "morans": morans, + "variances": variances, + "n_spots": n_after, + "n_scored": n_scored, + } + + +def print_report(results, threshold=None, pathways_filter=None): + """Print a formatted SVG analysis report.""" + + # Build a DataFrame: rows = pathways, columns = samples + all_pathways = results[0]["pathway_names"] + n_pathways = len(all_pathways) + n_samples = len(results) + + morans_matrix = np.zeros((n_samples, n_pathways), dtype=np.float32) + var_matrix = np.zeros((n_samples, n_pathways), dtype=np.float32) + sample_ids = [] + + for i, r in enumerate(results): + morans_matrix[i] = r["morans"] + var_matrix[i] = r["variances"] + sample_ids.append(r["sample_id"]) + + # Summary statistics per pathway + df = pd.DataFrame( + { + "Pathway": all_pathways, + "Mean_I": morans_matrix.mean(axis=0), + "Median_I": np.median(morans_matrix, axis=0), + "Std_I": morans_matrix.std(axis=0), + "Min_I": morans_matrix.min(axis=0), + "Max_I": morans_matrix.max(axis=0), + "Mean_Var": var_matrix.mean(axis=0), + } + ) + + # Count how many samples each pathway is above threshold + if threshold is not None: + df["Samples_Above"] = (morans_matrix >= threshold).sum(axis=0) + df["Pct_Above"] = df["Samples_Above"] / n_samples * 100 + + # Sort by mean Moran's I descending + df = df.sort_values("Mean_I", ascending=False).reset_index(drop=True) + df.index += 1 # 1-indexed rank + + # Apply pathway filter if specified + if pathways_filter: + filter_set = set(pathways_filter) + df["In_Filter"] = df["Pathway"].isin(filter_set) + + # --- Print Header --- + print() + print("=" * 90) + print(f" SVG ANALYSIS REPORT") + print(f" Samples: {n_samples} | Pathways: {n_pathways}") + for r 
in results: + print( + f" {r['sample_id']}: {r['n_spots']} spots, {r['n_scored']} pathways scored" + ) + print("=" * 90) + + # --- Print Table --- + print() + shorten = lambda s, n=42: s[: n - 2] + ".." if len(s) > n else s + + header = f"{'Rank':>4} {'Pathway':<44} {'Mean':>6} {'Med':>6} {'Std':>6} {'Min':>6} {'Max':>6} {'Var':>8}" + if threshold is not None: + header += f" {'Above':>8}" + if pathways_filter: + header += f" {'CRC':>3}" + print(header) + print("─" * len(header)) + + for rank, row in df.iterrows(): + line = ( + f"{rank:>4} {shorten(row['Pathway']):<44} " + f"{row['Mean_I']:>6.3f} {row['Median_I']:>6.3f} " + f"{row['Std_I']:>6.3f} {row['Min_I']:>6.3f} " + f"{row['Max_I']:>6.3f} {row['Mean_Var']:>8.4f}" + ) + if threshold is not None: + line += f" {int(row['Samples_Above']):>3}/{n_samples}" + if pathways_filter: + marker = " ✓" if row.get("In_Filter", False) else " " + line += f" {marker}" + print(line) + + # --- Threshold Summary --- + if threshold is not None: + print() + svg_df = df[df["Mean_I"] >= threshold] + non_svg = df[df["Mean_I"] < threshold] + + print(f"Pathways with mean Moran's I ≥ {threshold}: {len(svg_df)} / {len(df)}") + print(f"Pathways below threshold (noise): {len(non_svg)} / {len(df)}") + print() + + if len(svg_df) > 0: + print("# Recommended SVG pathway list (copy-paste into run_preset.py):") + print("SVG_PATHWAYS = [") + for _, row in svg_df.iterrows(): + print(f' "{row["Pathway"]}",') + print("]") + + # --- CRC Filter Summary --- + if pathways_filter: + print() + crc_df = df[df["In_Filter"]] + print(f"CRC pathways in filter: {len(crc_df)} / {len(pathways_filter)}") + if threshold is not None: + crc_svg = crc_df[crc_df["Mean_I"] >= threshold] + crc_noise = crc_df[crc_df["Mean_I"] < threshold] + print(f" Above threshold: {len(crc_svg)}") + print(f" Below threshold (noise): {len(crc_noise)}") + if len(crc_noise) > 0: + print(f" Noisy CRC pathways:") + for _, row in crc_noise.iterrows(): + print(f" - {row['Pathway']} (Mean I = 
{row['Mean_I']:.3f})") + + print() + return df + + +def plot_svg_analysis(df, results, output_dir=None): + """Generate exploratory plots of SVG statistics.""" + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed. Skipping plots.") + return + + if output_dir is None: + output_dir = "." + os.makedirs(output_dir, exist_ok=True) + + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + fig.suptitle("Spatially Variable Pathway Analysis", fontsize=14, fontweight="bold") + + # 1. Ranked bar chart of mean Moran's I + ax = axes[0, 0] + colors = ["#2ecc71" if v >= 0.05 else "#e74c3c" for v in df["Mean_I"]] + ax.barh(range(len(df)), df["Mean_I"].values, color=colors, height=0.7) + ax.set_yticks(range(len(df))) + ax.set_yticklabels([p[:30] for p in df["Pathway"]], fontsize=6) + ax.set_xlabel("Mean Moran's I") + ax.set_title("Pathway Spatial Autocorrelation (ranked)") + ax.axvline(x=0.05, color="gray", linestyle="--", alpha=0.7, label="Threshold=0.05") + ax.invert_yaxis() + ax.legend(fontsize=8) + + # 2. Histogram of all Moran's I values + ax = axes[0, 1] + all_morans = np.concatenate([r["morans"] for r in results]) + ax.hist(all_morans, bins=50, color="#3498db", edgecolor="white", alpha=0.8) + ax.axvline(x=0.05, color="red", linestyle="--", label="Threshold=0.05") + ax.set_xlabel("Moran's I") + ax.set_ylabel("Count") + ax.set_title("Distribution of Moran's I (all pathways, all samples)") + ax.legend() + + # 3. Moran's I vs Expression Variance + ax = axes[1, 0] + ax.scatter( + df["Mean_Var"], df["Mean_I"], c=colors, s=40, alpha=0.8, edgecolors="white" + ) + ax.set_xlabel("Mean Expression Variance") + ax.set_ylabel("Mean Moran's I") + ax.set_title("Spatial Structure vs Expression Variability") + for _, row in df.head(5).iterrows(): + ax.annotate( + row["Pathway"][:25], + (row["Mean_Var"], row["Mean_I"]), + fontsize=6, + alpha=0.7, + ) + + # 4. 
Per-sample heatmap (if multiple samples) + ax = axes[1, 1] + if len(results) > 1: + morans_matrix = np.array([r["morans"] for r in results]) + # Sort columns by mean Moran's I (same as df order) + sort_idx = np.argsort( + [np.mean(morans_matrix[:, i]) for i in range(morans_matrix.shape[1])] + )[::-1] + im = ax.imshow(morans_matrix[:, sort_idx], aspect="auto", cmap="RdYlGn") + ax.set_xlabel("Pathway (ranked)") + ax.set_ylabel("Sample") + ax.set_yticks(range(len(results))) + ax.set_yticklabels([r["sample_id"][:12] for r in results], fontsize=7) + ax.set_title("Moran's I Heatmap (samples × pathways)") + plt.colorbar(im, ax=ax, label="Moran's I") + else: + # Single sample: show bar chart per pathway + r = results[0] + sort_idx = np.argsort(r["morans"])[::-1] + pathway_names = [r["pathway_names"][i][:25] for i in sort_idx] + values = r["morans"][sort_idx] + bar_colors = ["#2ecc71" if v >= 0.05 else "#e74c3c" for v in values] + ax.bar(range(len(values)), values, color=bar_colors, width=0.8) + ax.set_xticks(range(len(values))) + ax.set_xticklabels(pathway_names, rotation=90, fontsize=5) + ax.set_ylabel("Moran's I") + ax.set_title(f"Per-Pathway Moran's I ({r['sample_id']})") + ax.axhline(y=0.05, color="gray", linestyle="--", alpha=0.7) + + plt.tight_layout() + plot_path = os.path.join(output_dir, "svg_analysis.png") + plt.savefig(plot_path, dpi=150, bbox_inches="tight") + print(f"Plot saved to: {plot_path}") + plt.close() + + +def plot_correlation_heatmap(results, output_dir=None, crc_pathways=None): + """Generate a pathway-pathway correlation heatmap from pooled spot activities. + + Concatenates spot-level activities across all samples, computes pairwise + Pearson correlations between pathways, and plots a clustered heatmap. + CRC pathways are highlighted and retained/dropped status is annotated. 
+ """ + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.patches import Patch + from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list + from scipy.spatial.distance import squareform + except ImportError: + print("matplotlib/scipy not installed. Skipping correlation plot.") + return + + if output_dir is None: + output_dir = "." + os.makedirs(output_dir, exist_ok=True) + + # Pool spot-level activities across all samples + pathway_names = results[0]["pathway_names"] + pooled = np.concatenate([r["activities"] for r in results], axis=0) + n_total_spots = pooled.shape[0] + + # Compute Pearson correlation matrix + # (columns = pathways, so corr across rows = spots) + corr = np.corrcoef(pooled.T) # (n_pathways, n_pathways) + np.fill_diagonal(corr, 1.0) + corr = np.nan_to_num(corr, nan=0.0) # zero-variance pathways + + # Hierarchical clustering for ordering + dist = 1 - corr + dist = np.clip(dist, 0, 2) # ensure valid distances + condensed = squareform(dist, checks=False) + Z = linkage(condensed, method="ward") + order = leaves_list(Z) + + corr_ordered = corr[np.ix_(order, order)] + names_ordered = [pathway_names[i] for i in order] + + # Determine CRC status for each pathway + crc_retained = set() + crc_dropped = set() + if crc_pathways: + # The dropped CRC pathways + dropped = { + "HALLMARK_TGF_BETA_SIGNALING", + "HALLMARK_ANGIOGENESIS", + "HALLMARK_WNT_BETA_CATENIN_SIGNALING", + "HALLMARK_KRAS_SIGNALING_DN", + } + crc_retained = set(crc_pathways) - dropped + crc_dropped = set(crc_pathways) & dropped + + # Shorten pathway names for labels + def short_name(name): + return name.replace("HALLMARK_", "").replace("_", " ").title()[:30] + + # Color-code labels + label_colors = [] + for name in names_ordered: + if name in crc_retained: + label_colors.append("#2ecc71") # green = retained CRC + elif name in crc_dropped: + label_colors.append("#e74c3c") # red = dropped CRC + else: + label_colors.append("#666666") # 
grey = non-CRC + + # --- Plot --- + fig, ax = plt.subplots(figsize=(16, 14)) + im = ax.imshow(corr_ordered, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto") + + ax.set_xticks(range(len(names_ordered))) + ax.set_yticks(range(len(names_ordered))) + ax.set_xticklabels([short_name(n) for n in names_ordered], rotation=90, fontsize=7) + ax.set_yticklabels([short_name(n) for n in names_ordered], fontsize=7) + + # Colour the tick labels + for i, (xtick, ytick) in enumerate(zip(ax.get_xticklabels(), ax.get_yticklabels())): + xtick.set_color(label_colors[i]) + ytick.set_color(label_colors[i]) + + ax.set_title( + f"Pathway Activity Correlations (Pearson, {n_total_spots:,} pooled spots from {len(results)} samples)", + fontsize=12, + fontweight="bold", + pad=15, + ) + + plt.colorbar(im, ax=ax, label="Pearson r", shrink=0.8) + + # Legend + legend_elements = [ + Patch(facecolor="#2ecc71", label="CRC — Retained (spatial)"), + Patch(facecolor="#e74c3c", label="CRC — Dropped (non-spatial)"), + Patch(facecolor="#666666", label="Non-CRC pathway"), + ] + ax.legend(handles=legend_elements, loc="lower left", fontsize=9) + + plt.tight_layout() + plot_path = os.path.join(output_dir, "pathway_correlations.png") + plt.savefig(plot_path, dpi=150, bbox_inches="tight") + print(f"Correlation plot saved to: {plot_path}") + plt.close() + + # --- Print correlation summary for dropped CRC pathways --- + if crc_dropped: + print("\nCorrelations between DROPPED CRC pathways and RETAINED CRC pathways:") + print(f"{'Dropped Pathway':<44} {'Retained Pathway':<44} {'Pearson r':>9}") + print("─" * 100) + for dropped_name in sorted(crc_dropped): + if dropped_name not in pathway_names: + continue + di = pathway_names.index(dropped_name) + pairs = [] + for retained_name in sorted(crc_retained): + if retained_name not in pathway_names: + continue + ri = pathway_names.index(retained_name) + r_val = corr[di, ri] + pairs.append((retained_name, r_val)) + # Sort by absolute correlation + pairs.sort(key=lambda x: 
abs(x[1]), reverse=True) + for retained_name, r_val in pairs[:5]: # top 5 + short_d = dropped_name.replace("HALLMARK_", "") + short_r = retained_name.replace("HALLMARK_", "") + print(f" {short_d:<42} {short_r:<42} {r_val:>+9.3f}") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Exploratory analysis of Spatially Variable Pathways (SVGs)." + ) + parser.add_argument( + "--data-dir", + type=str, + required=True, + help="Root HEST data directory (contains st/ subdirectory)", + ) + parser.add_argument( + "--sample-ids", + nargs="+", + default=None, + help="Specific sample IDs to analyze (default: all .h5ad in st/)", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.05, + help="Moran's I threshold to classify SVGs (default: 0.05)", + ) + parser.add_argument( + "--crc", + action="store_true", + help="Highlight CRC-specific pathways in the report", + ) + parser.add_argument( + "--plot", + action="store_true", + help="Generate exploratory plots (saved to --output-dir)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory for plot output (default: /svg_analysis)", + ) + parser.add_argument( + "--k", + type=int, + default=6, + help="Number of spatial neighbours for Moran's I (default: 6)", + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + + st_dir = os.path.join(args.data_dir, "st") + if not os.path.isdir(st_dir): + print(f"Error: Could not find st/ directory in {args.data_dir}") + sys.exit(1) + + # Discover samples + if args.sample_ids: + sample_ids = args.sample_ids + else: + sample_ids = [f[:-5] for f in os.listdir(st_dir) if f.endswith(".h5ad")] + sample_ids.sort() + + if not sample_ids: + print(f"No .h5ad files found in {st_dir}") + sys.exit(1) + + print(f"Loading {len(sample_ids)} sample(s)...") + pathway_dict = _load_hallmark_sets(cache_dir=os.path.join(args.data_dir, ".cache")) + + results = [] + 
for sample_id in sample_ids: + h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") + if not os.path.exists(h5ad_path): + logger.warning(f"Missing: {h5ad_path}") + continue + + print(f" Analyzing {sample_id}...", end=" ", flush=True) + result = analyze_sample(h5ad_path, pathway_dict, k=args.k) + if result is not None: + results.append(result) + print(f"OK ({result['n_spots']} spots, {result['n_scored']} pathways)") + else: + print("SKIPPED") + + if not results: + print("No samples could be analyzed.") + sys.exit(1) + + # CRC pathway list for filtering + crc_pathways = None + if args.crc: + crc_pathways = [ + "HALLMARK_WNT_BETA_CATENIN_SIGNALING", + "HALLMARK_TGF_BETA_SIGNALING", + "HALLMARK_KRAS_SIGNALING_UP", + "HALLMARK_KRAS_SIGNALING_DN", + "HALLMARK_PI3K_AKT_MTOR_SIGNALING", + "HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION", + "HALLMARK_ANGIOGENESIS", + "HALLMARK_APICAL_JUNCTION", + "HALLMARK_INFLAMMATORY_RESPONSE", + "HALLMARK_IL6_JAK_STAT3_SIGNALING", + "HALLMARK_APOPTOSIS", + "HALLMARK_P53_PATHWAY", + "HALLMARK_DNA_REPAIR", + "HALLMARK_HYPOXIA", + ] + + df = print_report(results, threshold=args.threshold, pathways_filter=crc_pathways) + + # Save to CSV file for easy viewing + output_dir = args.output_dir or os.path.join(args.data_dir, "svg_analysis") + os.makedirs(output_dir, exist_ok=True) + report_path = os.path.join(output_dir, "svg_report.csv") + try: + df.to_csv(report_path, index=True) + print(f"Report saved to: {report_path}") + except PermissionError: + # File might be open in another app + from datetime import datetime + + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + alt_path = os.path.join(output_dir, f"svg_report_{ts}.csv") + df.to_csv(alt_path, index=True) + print(f"Report saved to: {alt_path} (original was locked)") + + if args.plot: + output_dir = args.output_dir or os.path.join(args.data_dir, "svg_analysis") + plot_svg_analysis(df, results, output_dir=output_dir) + plot_correlation_heatmap( + results, output_dir=output_dir, 
crc_pathways=crc_pathways + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_preset.py b/scripts/run_preset.py index 1ddf630..c4f958e 100644 --- a/scripts/run_preset.py +++ b/scripts/run_preset.py @@ -3,7 +3,9 @@ from spatial_transcript_former.config import get_config -# Curated list of MSigDB Hallmarks with strong evidence of involvement in Colorectal/Bowel Cancer +# Curated list of MSigDB Hallmarks with strong evidence of involvement in Colorectal/Bowel Cancer. +# All 14 pathways have been validated to exhibit significant spatial autocorrelation +# (Mean Moran's I > 0.20) across 85 HEST samples. CRC_PATHWAYS = [ "HALLMARK_WNT_BETA_CATENIN_SIGNALING", "HALLMARK_TGF_BETA_SIGNALING", @@ -70,11 +72,72 @@ def make_stf_params(n_layers: int, token_dim: int, n_heads: int, batch_size: int "stf_small": make_stf_params(n_layers=4, token_dim=384, n_heads=8, batch_size=8), "stf_medium": make_stf_params(n_layers=6, token_dim=512, n_heads=8, batch_size=8), "stf_large": make_stf_params(n_layers=12, token_dim=768, n_heads=12, batch_size=8), - # --- Biologically-Prioritized Variants (e.g. Colorectal Cancer) --- - "stf_crc_tiny": {**make_stf_params(2, 256, 4, 8), "pathways": CRC_PATHWAYS}, - "stf_crc_small": {**make_stf_params(4, 384, 8, 8), "pathways": CRC_PATHWAYS}, - "stf_crc_medium": {**make_stf_params(6, 512, 8, 8), "pathways": CRC_PATHWAYS}, - "stf_crc_large": {**make_stf_params(12, 768, 12, 8), "pathways": CRC_PATHWAYS}, + # --- Biologically-Prioritized CRC Variants: MSE + CCC --- + # CCC penalises mean/variance offset that PCC ignores — better regression agreement. 
+ "stf_crc_tiny": { + **make_stf_params(2, 256, 4, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc", + }, + "stf_crc_small": { + **make_stf_params(4, 384, 8, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc", + }, + "stf_crc_medium": { + **make_stf_params(6, 512, 8, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc", + }, + "stf_crc_large": { + **make_stf_params(12, 768, 12, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc", + }, + # --- Biologically-Prioritized CRC Variants: MSE + CCC + CLIP --- + # Adds CLIP alignment to prevent variance collapse (all predictions collapsing to mean). + "stf_crc_tiny_clip": { + **make_stf_params(2, 256, 4, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc_clip", + }, + "stf_crc_small_clip": { + **make_stf_params(4, 384, 8, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc_clip", + }, + "stf_crc_medium_clip": { + **make_stf_params(6, 512, 8, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc_clip", + }, + "stf_crc_large_clip": { + **make_stf_params(12, 768, 12, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_ccc_clip", + }, + # --- Biologically-Prioritized CRC Variants: Huber + CCC --- + # Huber replaces MSE for robustness against outlier pathway activity values. 
+ "stf_crc_tiny_huber": { + **make_stf_params(2, 256, 4, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_huber", + }, + "stf_crc_small_huber": { + **make_stf_params(4, 384, 8, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_huber", + }, + "stf_crc_medium_huber": { + **make_stf_params(6, 512, 8, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_huber", + }, + "stf_crc_large_huber": { + **make_stf_params(12, 768, 12, 8), + "pathways": CRC_PATHWAYS, + "loss": "mse_huber", + }, } diff --git a/src/spatial_transcript_former/checkpoint.py b/src/spatial_transcript_former/checkpoint.py index d7e48ce..75a78f5 100644 --- a/src/spatial_transcript_former/checkpoint.py +++ b/src/spatial_transcript_former/checkpoint.py @@ -102,8 +102,14 @@ def save_pretrained( """ os.makedirs(save_dir, exist_ok=True) - # 1. Config + # 1. Config (stamps the pathway-target format version so inference + # callers can detect mismatched preprocessing pipelines). + from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + PATHWAY_FILE_VERSION, + ) + config = _model_config(model) + config["pathway_format_version"] = PATHWAY_FILE_VERSION with open(os.path.join(save_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) @@ -169,6 +175,21 @@ def load_pretrained( # Don't load pretrained backbone weights — we're loading our own config["pretrained"] = False + # Strip metadata fields that aren't constructor arguments. Validate the + # pathway-target format version so old checkpoints don't silently load + # against incompatible target semantics. + from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + PATHWAY_FILE_VERSION, + ) + + saved_pfv = config.pop("pathway_format_version", None) + if saved_pfv is not None and int(saved_pfv) != PATHWAY_FILE_VERSION: + raise ValueError( + f"Checkpoint at {checkpoint_dir!r} was trained against pathway " + f"targets of format_version={saved_pfv}, but this build expects " + f"{PATHWAY_FILE_VERSION}. 
Re-train against current targets." + ) + # 2. Instantiate model = SpatialTranscriptFormer(**config) diff --git a/src/spatial_transcript_former/dashboard/callbacks.py b/src/spatial_transcript_former/dashboard/callbacks.py index 2f5cb9b..62b9b14 100644 --- a/src/spatial_transcript_former/dashboard/callbacks.py +++ b/src/spatial_transcript_former/dashboard/callbacks.py @@ -293,7 +293,7 @@ def update_metrics(n, smoothing_window, selected_runs): ) # Correlation / Errors - corr_cols = ["val_pcc", "val_mae"] + corr_cols = ["val_pcc", "val_ccc", "val_mae"] pcc_fig = go.Figure(data=_make_traces(data_dict, corr_cols, smoothing_window)) pcc_fig.update_layout( title="Validation Metrics", @@ -395,6 +395,14 @@ def update_metrics(n, smoothing_window, selected_runs): f"{run_lbl}Epoch {last_epoch}", ) ) + if "val_ccc" in df.columns: + kpi_elements.append( + create_kpi_card( + "Val CCC", + f"{last_row['val_ccc']:.4f}", + f"{run_lbl}Epoch {last_epoch}", + ) + ) if "lr" in df.columns: kpi_elements.append( create_kpi_card("Learning Rate", f"{last_row['lr']:.2e}") diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index d6bc3e6..03ccbd9 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -17,9 +17,10 @@ class LearnedSpatialEncoder(nn.Module): """Encodes 2D spatial coordinates via a small learned MLP. - Unlike sinusoidal PE, this produces smooth, non-periodic embeddings - that vary gradually across the tissue. Coordinates are normalised to - [-1, 1] per-batch before encoding for scale invariance. + Inputs are expected to already be slide-stationary (centred and + standardised at the dataset level), so this module is permutation- + invariant per spot — the embedding for a spot depends only on its + own coordinates, not on what other spots are in the batch. 
""" def __init__(self, dim): @@ -30,23 +31,15 @@ def __init__(self, dim): nn.Linear(dim, dim), ) - def _normalize_coords(self, coords): - """Normalize coordinates to [-1, 1] range per batch.""" - # Centre at zero - coords = coords - coords.mean(dim=1, keepdim=True) - # Scale to [-1, 1] - scale = coords.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1.0) - return coords / scale - def forward(self, rel_coords): """ Args: - rel_coords (torch.Tensor): (B, S, 2) spatial coordinates. + rel_coords (torch.Tensor): (B, S, 2) slide-stationary coordinates. Returns: torch.Tensor: (B, S, dim) positional embeddings. """ - return self.mlp(self._normalize_coords(rel_coords)) + return self.mlp(rel_coords) VALID_INTERACTIONS = {"p2p", "p2h", "h2p", "h2h"} diff --git a/src/spatial_transcript_former/predict.py b/src/spatial_transcript_former/predict.py index 280385f..0cbf1db 100644 --- a/src/spatial_transcript_former/predict.py +++ b/src/spatial_transcript_former/predict.py @@ -217,7 +217,13 @@ def predict_wsi( ) features = features.to(self.device) - coords = coords.to(self.device) + coords = coords.to(self.device).float() + + # Match the dataset-side slide-stationary normalisation so the model + # sees coordinates on the same scale it was trained on. + center = coords.mean(dim=1, keepdim=True) + scale = coords.std(dim=1).max(dim=-1, keepdim=True).values.clamp(min=1.0) + coords = (coords - center) / scale.unsqueeze(1) with torch.amp.autocast("cuda", enabled=self.use_amp): result = self.model( @@ -382,8 +388,8 @@ def plot_training_summary( Args: coords: (N, 2) spatial coordinates. - pathway_pred: (N, P) predicted pathway activations (spatial z-score). - pathway_truth: (N, P) ground truth pathway activations (spatial z-score). + pathway_pred: (N, P) predicted pathway activations (mean log1p CP10k). + pathway_truth: (N, P) ground truth pathway activations (mean log1p CP10k). pathway_names: List of P pathway names. sample_id: Identifier for the plot title. 
histology_img: Optional RGB image for background. @@ -437,7 +443,7 @@ def plot_training_summary( fig.patch.set_facecolor("#0f172a") plt.suptitle( - f"Pathway Validation (Spatial Z-Score): {sample_id}", + f"Pathway Validation: {sample_id}", fontsize=18, color="white", fontweight="bold", @@ -446,11 +452,13 @@ def plot_training_summary( for i, idx in enumerate(selected_indices): name = plot_names[i] - # Pred and Truth for this pathway (z-score scale) + # Pred and Truth for this pathway (mean log1p CP10k units) p = pathway_pred[:, idx] t = pathway_truth[:, idx] - # Vmin/Vmax for shared z-score scale + # Shared vmin/vmax across truth and prediction so both panels are + # directly comparable. Per-pathway range, since absolute scales + # differ across pathways under raw-mean targets. vmin = min(p.min(), t.min()) vmax = max(p.max(), t.max()) @@ -506,7 +514,7 @@ def plot_training_summary( aspect=40, pad=0.05, ) - cbar.set_label("Relative Expression (Spatial Z-Score)", fontsize=14, labelpad=10) + cbar.set_label("Pathway score (mean log1p CP10k)", fontsize=14, labelpad=10) cbar.ax.tick_params(labelsize=12) plt.savefig(save_path, dpi=200, bbox_inches="tight", facecolor="#0f172a") diff --git a/src/spatial_transcript_former/recipes/hest/compute_pathway_activities.py b/src/spatial_transcript_former/recipes/hest/compute_pathway_activities.py index 539d20e..e464784 100644 --- a/src/spatial_transcript_former/recipes/hest/compute_pathway_activities.py +++ b/src/spatial_transcript_former/recipes/hest/compute_pathway_activities.py @@ -5,19 +5,23 @@ 1. Loads the raw gene expression matrix (spots x genes) 2. Applies per-spot QC (min UMIs, min genes, max MT%) on raw counts 3. Applies CP10k normalisation + log1p to surviving spots - 4. Z-scores each gene across spots, then computes per-pathway mean z-score + 4. Computes per-pathway scores as the mean log1p CP10k expression of + member genes (no per-slide normalisation — targets are slide-stationary) 5. 
Saves the resulting activity matrix to /pathway_activities/.h5 Pathway scores are computed from MSigDB Hallmark gene sets (50 pathways). -For each pathway, the score per spot is the mean z-scored expression of -member genes present in the expression matrix. +The score for spot s and pathway p is the simple per-spot mean of the log1p +CP10k expression across the pathway's member genes that are present in the +sample. CP10k handles depth normalisation; no per-slide statistics enter the +score, so the same biological state in two slides yields the same target. Non-human samples are auto-skipped via HEST metadata. Samples with fewer than ``--min-pathways`` scored pathways are excluded. The saved files are consumed at training time by HEST_FeatureDataset when -``pathway_targets_dir`` is provided. +``pathway_targets_dir`` is provided. Files carry a ``format_version`` +attribute; loaders refuse mismatched versions and ask for a recompute. Usage:: @@ -47,6 +51,13 @@ logger = logging.getLogger(__name__) +# File-format version stamped into each pathway-activities .h5 file. +# Bumped whenever the on-disk semantics of `activities` change. +# v1: per-slide z-scored pathway-mean (deprecated — slide-relative drift). +# v2: plain mean of log1p CP10k expression of pathway members. +PATHWAY_FILE_VERSION = 2 + + def _load_expression( h5ad_path: str, target_sum: int = 10_000, @@ -145,7 +156,13 @@ def _load_hallmark_sets(cache_dir: str = ".cache"): def _score_pathways(expr_matrix, gene_names, pathway_dict, min_genes=5): - """Score pathway activities via mean z-scored expression of member genes. + """Score pathway activities as the mean log1p CP10k expression of member genes. + + The score for spot s and pathway p is the simple per-spot mean across the + pathway's member genes that are present in ``gene_names``. 
This is depth- + normalised (via the prior CP10k step) and slide-stationary by construction: + no per-slide statistics enter the score, so the same biological state in + two different slides yields the same target value. Parameters ---------- @@ -172,12 +189,6 @@ def _score_pathways(expr_matrix, gene_names, pathway_dict, min_genes=5): gene_to_idx = {g: i for i, g in enumerate(gene_names)} n_spots = expr_matrix.shape[0] - # Z-score each gene across spots (zero-variance genes get z=0) - means = expr_matrix.mean(axis=0) - stds = expr_matrix.std(axis=0) - stds[stds == 0] = 1.0 # avoid division by zero - z_matrix = (expr_matrix - means) / stds - all_pathways = list(pathway_dict.keys()) activities = np.zeros((n_spots, len(all_pathways)), dtype=np.float32) n_scored = 0 @@ -186,7 +197,7 @@ def _score_pathways(expr_matrix, gene_names, pathway_dict, min_genes=5): col_indices = [gene_to_idx[g] for g in pw_genes if g in gene_to_idx] if len(col_indices) < min_genes: continue - activities[:, i] = z_matrix[:, col_indices].mean(axis=1) + activities[:, i] = expr_matrix[:, col_indices].mean(axis=1) n_scored += 1 return activities, all_pathways, n_scored @@ -370,6 +381,8 @@ def compute_pathway_activities_for_sample( f.create_dataset("pathway_names", data=pathway_names) if pathway_morans is not None: f.create_dataset("pathway_morans_i", data=pathway_morans) + # File-format version — bumped when the semantics of `activities` change. + f.attrs["format_version"] = PATHWAY_FILE_VERSION # QC metadata for downstream auditing f.attrs["n_spots_before_qc"] = n_before f.attrs["n_spots_after_qc"] = n_after @@ -408,10 +421,24 @@ def load_pathway_activities( valid_mask : np.ndarray, shape (N_barcodes,), bool True for barcodes that were found in the activity file. pathway_morans_i : np.ndarray or None, shape (P,), float32 - Per-pathway Moran's I weights. ``None`` for older files that - were computed before this field was added. + Per-pathway Moran's I diagnostic. ``None`` if the field is absent. 
+ + Raises + ------ + ValueError + If the file's ``format_version`` attribute is missing or does not + match :data:`PATHWAY_FILE_VERSION`. Re-run ``stf-compute-pathways + --overwrite`` to regenerate the file with the current schema. """ with h5py.File(h5_path, "r") as f: + version = f.attrs.get("format_version", None) + if version is None or int(version) != PATHWAY_FILE_VERSION: + raise ValueError( + f"Pathway file {h5_path!r} has format_version=" + f"{version!r}, but this build expects " + f"{PATHWAY_FILE_VERSION}. Re-run " + "`stf-compute-pathways --overwrite` to regenerate." + ) stored_acts = f["activities"][:] # (M, P) stored_barcodes = f["barcodes"][:] # bytes array pathway_names = [ diff --git a/src/spatial_transcript_former/recipes/hest/dataset.py b/src/spatial_transcript_former/recipes/hest/dataset.py index d1e2e23..ba5c917 100644 --- a/src/spatial_transcript_former/recipes/hest/dataset.py +++ b/src/spatial_transcript_former/recipes/hest/dataset.py @@ -414,24 +414,28 @@ def _load_data(self): mask_bool = np.array(mask, dtype=bool) self.features = features[mask_bool] # (N_valid, D) coords_valid = coords[mask_bool].numpy() - self.coords = torch.from_numpy( + grid_coords = torch.from_numpy( normalize_coordinates(coords_valid) - ) # (N_valid, 2) + ).float() # (N_valid, 2) — integer-grid units + + # Slide-stationary normalisation: centre at slide centroid, divide by + # per-slide scale. PE consumers therefore see coordinates that depend + # only on this slide's geometry, not on batch composition. 
+ center = grid_coords.mean(dim=0, keepdim=True) + scale = grid_coords.std(dim=0).max().clamp(min=1.0) + self.coords = (grid_coords - center) / scale # (N_valid, 2) self.genes = None # gene_matrix removed self.kdtree = KDTree(self.coords.numpy()) # Load pathway activity targets if a directory is provided self.pathway_activities = None - self.pathway_morans_i = None if self.pathway_targets_dir is not None: sample_id = os.path.splitext(os.path.basename(self.feature_path))[0] h5_path = os.path.join(self.pathway_targets_dir, f"{sample_id}.h5") if os.path.exists(h5_path): from .compute_pathway_activities import load_pathway_activities - acts, pw_names, _, pw_morans = load_pathway_activities( - h5_path, list(barcodes) - ) + acts, pw_names, _, _ = load_pathway_activities(h5_path, list(barcodes)) if self.target_pathway_names is not None: # Filter pathways to match the requested subset @@ -443,32 +447,19 @@ def _load_data(self): # If a requested pathway is missing, we'll use a zero column indices.append(-1) - # Subset and handle missing pathways (as zeros) p = len(self.target_pathway_names) subset_acts = np.zeros((acts.shape[0], p), dtype=np.float32) - subset_morans = np.zeros(p, dtype=np.float32) - for i, idx in enumerate(indices): if idx != -1: subset_acts[:, i] = acts[:, idx] - if pw_morans is not None: - subset_morans[i] = pw_morans[idx] self.pathway_activities = torch.tensor( subset_acts[mask_bool], dtype=torch.float32 ) - self.pathway_morans_i = torch.tensor( - subset_morans, dtype=torch.float32 - ) else: - # No subsetting: load all pathways self.pathway_activities = torch.tensor( acts[mask_bool], dtype=torch.float32 ) - if pw_morans is not None: - self.pathway_morans_i = torch.tensor( - pw_morans, dtype=torch.float32 - ) def __len__(self): return 1 if self.whole_slide_mode else len(self.coords) @@ -493,11 +484,6 @@ def _getitem_whole_slide(self): else None ), co, - ( - self.pathway_morans_i.clone() - if self.pathway_morans_i is not None - else None - ), ) def 
_getitem_patch(self, idx): @@ -561,8 +547,7 @@ def _getitem_patch(self, idx): if self.pathway_activities is not None else None ) - pathway_morans = self.pathway_morans_i # (P,) or None — same for all spots - return feats, target_genes, pathway_acts, rel_coords, pathway_morans + return feats, target_genes, pathway_acts, rel_coords # --------------------------------------------------------------------------- @@ -571,18 +556,17 @@ def _getitem_patch(self, idx): def collate_fn_patch(batch): - """Collate ``(feats, genes, pathway_acts, coords, pathway_morans)`` tuples. + """Collate ``(feats, genes, pathway_acts, coords)`` tuples. - Handles ``pathway_acts=None`` and ``pathway_morans=None`` (when no - pathway targets dir is configured) by passing ``None`` through. + Handles ``pathway_acts=None`` (when no pathway targets dir is configured) + by passing ``None`` through. Args: - batch: List of ``(feats, genes, pathway_acts, coords, pathway_morans)`` - tuples. + batch: List of ``(feats, genes, pathway_acts, coords)`` tuples. Returns: - tuple: ``(feats, genes, pathway_acts, coords, pathway_morans)`` where - ``pathway_acts`` and ``pathway_morans`` are stacked tensors or ``None``. + tuple: ``(feats, genes, pathway_acts, coords)`` where ``pathway_acts`` + is a stacked tensor or ``None``. """ feats = torch.stack([item[0] for item in batch]) genes = ( @@ -591,9 +575,7 @@ def collate_fn_patch(batch): has_pathways = batch[0][2] is not None pathways = torch.stack([item[2] for item in batch]) if has_pathways else None coords = torch.stack([item[3] for item in batch]) - has_morans = batch[0][4] is not None - morans = torch.stack([item[4] for item in batch]) if has_morans else None - return feats, genes, pathways, coords, morans + return feats, genes, pathways, coords # --------------------------------------------------------------------------- @@ -692,23 +674,20 @@ def collate_fn_ws(batch): """Pad variable-length slides to the longest in the batch. 
Args: - batch: List of ``(feats, genes, pathway_acts, coords, pathway_morans)`` - tuples where each tensor has a variable first dimension - (number of patches). ``pathway_acts`` and - ``pathway_morans`` may be ``None``. + batch: List of ``(feats, genes, pathway_acts, coords)`` tuples + where each tensor has a variable first dimension (number + of patches). ``pathway_acts`` may be ``None``. Returns: tuple: ``(padded_feats, padded_genes, padded_pathway_acts, - padded_coords, mask, pathway_morans)`` where ``mask`` is - ``True`` for padding positions. ``padded_pathway_acts`` and - ``pathway_morans`` are ``None`` when not loaded. + padded_coords, mask)`` where ``mask`` is ``True`` for padding + positions. ``padded_pathway_acts`` is ``None`` when not + loaded. """ - # lengths and common dims lengths = [item[0].shape[0] for item in batch] max_len = max(lengths) d_dim = batch[0][0].shape[1] has_pathways = batch[0][2] is not None - has_morans = batch[0][4] is not None bs = len(batch) padded_feats = torch.zeros(bs, max_len, d_dim) @@ -722,13 +701,7 @@ def collate_fn_ws(batch): else: padded_pathways = None - # pathway_morans is per-sample (P,), no spatial padding needed - if has_morans: - stacked_morans = torch.stack([item[4] for item in batch]) # (B, P) - else: - stacked_morans = None - - for i, (f, g, pw, c, _pm) in enumerate(batch): + for i, (f, g, pw, c) in enumerate(batch): l = lengths[i] padded_feats[i, :l] = f padded_coords[i, :l] = c @@ -742,7 +715,6 @@ def collate_fn_ws(batch): padded_pathways, padded_coords, mask, - stacked_morans, ) return DataLoader( diff --git a/src/spatial_transcript_former/train.py b/src/spatial_transcript_former/train.py index 72ddbd2..e8a9857 100644 --- a/src/spatial_transcript_former/train.py +++ b/src/spatial_transcript_former/train.py @@ -15,11 +15,6 @@ from spatial_transcript_former.config import get_config from spatial_transcript_former.models import HE2RNA, ViT_ST, SpatialTranscriptFormer from spatial_transcript_former.utils import 
set_seed -from spatial_transcript_former.training.losses import ( - PCCLoss, - CompositeLoss, - MaskedMSELoss, -) from spatial_transcript_former.training.engine import train_one_epoch, validate from spatial_transcript_former.training.experiment_logger import ExperimentLogger from spatial_transcript_former.visualization import run_inference_plot @@ -34,6 +29,48 @@ save_checkpoint, load_checkpoint, ) +from spatial_transcript_former.checkpoint import save_pretrained +from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + PATHWAY_FILE_VERSION, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _resolve_pathway_names(args): + """Best-effort recovery of pathway names for ``save_pretrained``. + + Order: + 1. ``args.pathways`` if it's a non-empty list. + 2. The ``pathway_names`` dataset of the first .h5 in + ``args.pathway_targets_dir``. + 3. ``None`` (skip writing pathway_names.json). 
+ """ + explicit = getattr(args, "pathways", None) + if explicit and isinstance(explicit, (list, tuple)) and len(explicit) > 0: + return list(explicit) + + targets_dir = getattr(args, "pathway_targets_dir", None) + if targets_dir and os.path.isdir(targets_dir): + for fname in sorted(os.listdir(targets_dir)): + if not fname.endswith(".h5"): + continue + try: + import h5py + + with h5py.File(os.path.join(targets_dir, fname), "r") as f: + if "pathway_names" in f: + return [ + n.decode() if isinstance(n, bytes) else n + for n in f["pathway_names"][:] + ] + except Exception: + pass + break # only inspect the first .h5 + return None + # --------------------------------------------------------------------------- # Main @@ -92,6 +129,11 @@ def main(): scaler = torch.amp.GradScaler("cuda") if args.use_amp else None print(f"Loss: {criterion.__class__.__name__}") print(f"LR schedule: {warmup_epochs}-epoch warmup -> cosine annealing to 1e-6") + print( + f"Targets: pathway_format_version={PATHWAY_FILE_VERSION} " + "(mean log1p CP10k of pathway members). " + "Validation MAE/loss are in those units; best-model selection uses CCC." + ) # 3. Output & Logger os.makedirs(args.output_dir, exist_ok=True) @@ -99,10 +141,13 @@ def main(): logger = ExperimentLogger(args.output_dir, config_dict) # 4. Resume - start_epoch, best_val_loss = 0, float("inf") + # ``best_val_metric`` tracks the highest val_ccc seen so far (CCC is + # higher-is-better and pathway-scale-invariant; preferable to MSE-based + # selection now that targets live in raw log1p CP10k units). 
+ start_epoch, best_val_metric = 0, -float("inf") schedulers = {"main": main_scheduler} if args.resume: - start_epoch, best_val_loss, loaded_schedulers = load_checkpoint( + start_epoch, best_val_metric, loaded_schedulers = load_checkpoint( model, optimizer, scaler, schedulers, args.output_dir, args.model, device ) @@ -156,6 +201,8 @@ def main(): epoch_row["val_mae"] = round(val_metrics["val_mae"], 4) if val_metrics.get("val_pcc") is not None: epoch_row["val_pcc"] = round(val_metrics["val_pcc"], 4) + if val_metrics.get("val_ccc") is not None: + epoch_row["val_ccc"] = round(val_metrics["val_ccc"], 4) if val_metrics.get("pred_variance") is not None: epoch_row["pred_variance"] = round(val_metrics["pred_variance"], 6) if val_metrics.get("spatial_coherence") is not None: @@ -179,12 +226,28 @@ def main(): logger.log_epoch(epoch + 1, epoch_row) - # Save best - if val_loss < best_val_loss: - best_val_loss = val_loss + # Save best — selection driven by CCC (higher is better) + val_ccc = val_metrics.get("val_ccc") + if val_ccc is not None and val_ccc > best_val_metric: + best_val_metric = val_ccc + + # Legacy state_dict path (kept for tools that still load .pth directly) best_path = os.path.join(args.output_dir, f"best_model_{args.model}.pth") torch.save(model.state_dict(), best_path) - print(f"Saved best model -> {best_path}") + + # Self-contained checkpoint directory (config.json + model.pth + + # optional pathway_names.json) so inference can rebuild the model + # without re-specifying architecture flags. 
+ best_dir = os.path.join(args.output_dir, f"best_{args.model}") + try: + save_pretrained( + model, + best_dir, + pathway_names=_resolve_pathway_names(args), + ) + except Exception as e: + print(f" (skipped save_pretrained bundle: {e})") + print(f"Saved best model (val_ccc={val_ccc:.4f}) -> {best_path}") # Save latest save_checkpoint( @@ -193,7 +256,7 @@ def main(): scaler, schedulers, epoch, - best_val_loss, + best_val_metric, args.output_dir, args.model, ) @@ -205,8 +268,11 @@ def main(): run_inference_plot(model, args, vis_id, epoch + 1, device) # 6. Finalize - logger.finalize(best_val_loss) - print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}") + logger.finalize(best_val_metric) + if best_val_metric == -float("inf"): + print("\nTraining complete. No valid CCC was recorded.") + else: + print(f"\nTraining complete. Best val CCC: {best_val_metric:.4f}") if __name__ == "__main__": diff --git a/src/spatial_transcript_former/training/arguments.py b/src/spatial_transcript_former/training/arguments.py index f35daed..e2ff8a4 100644 --- a/src/spatial_transcript_former/training/arguments.py +++ b/src/spatial_transcript_former/training/arguments.py @@ -40,16 +40,30 @@ def parse_args(): choices=[ "mse", "pcc", + "ccc", "mse_pcc", - "poisson", - "logcosh", + "mse_ccc", + "mse_ccc_clip", + "mse_huber", ], ) parser.add_argument( "--pcc-weight", type=float, default=1.0, - help="Weight for PCC term in mse_pcc loss", + help="Weight for PCC/CCC term in composite losses", + ) + parser.add_argument( + "--clip-weight", + type=float, + default=0.5, + help="Weight for CLIP alignment term (mse_ccc_clip only)", + ) + parser.add_argument( + "--clip-temp", + type=float, + default=0.07, + help="Temperature τ for CLIP alignment loss", ) parser.add_argument( "--pathway-targets-dir", @@ -57,13 +71,6 @@ def parse_args(): default=None, help="Directory of pre-computed pathway activity .h5 files", ) - parser.add_argument( - "--morans-pathway-weight", - action="store_true", - 
help="Weight MSE loss per-pathway by Moran's I spatial autocorrelation. " - "Requires pathway .h5 files to contain pathway_morans_i " - "(re-run stf-compute-pathways --overwrite to add them).", - ) # Model g = parser.add_argument_group("Model") @@ -124,6 +131,13 @@ def parse_args(): ) g.add_argument("--compile", action="store_true") g.add_argument("--resume", action="store_true") + g.add_argument( + "--run-name", + type=str, + default=None, + help="Name used for checkpoint files and logs (defaults to --model if unset). " + "Set automatically by run_preset.py to the preset name.", + ) # Advanced g = parser.add_argument_group("Advanced") diff --git a/src/spatial_transcript_former/training/builder.py b/src/spatial_transcript_former/training/builder.py index 61633f2..ddaba85 100644 --- a/src/spatial_transcript_former/training/builder.py +++ b/src/spatial_transcript_former/training/builder.py @@ -3,9 +3,12 @@ import torch.nn as nn from spatial_transcript_former.models import HE2RNA, ViT_ST, SpatialTranscriptFormer from spatial_transcript_former.training.losses import ( - PCCLoss, + CCCLoss, + CLIPAlignmentLoss, CompositeLoss, + MaskedHuberLoss, MaskedMSELoss, + PCCLoss, ) @@ -83,14 +86,28 @@ def setup_model(args, device): def setup_criterion(args): """Create loss function from CLI args.""" + clip_w = getattr(args, "clip_weight", 0.0) + clip_t = getattr(args, "clip_temp", 0.07) + if args.loss == "pcc": return PCCLoss() + elif args.loss == "ccc": + return CCCLoss() elif args.loss == "mse_pcc": return CompositeLoss(alpha=args.pcc_weight) - elif args.loss == "poisson": - return nn.PoissonNLLLoss(log_input=True) - elif args.loss == "logcosh": - print("Using HuberLoss as proxy for LogCosh") - return nn.HuberLoss() + elif args.loss == "mse_ccc": + return CompositeLoss(alpha=args.pcc_weight, pcc_type="ccc") + elif args.loss == "mse_ccc_clip": + # CLIP term here is the batch-discriminative regulariser in + # pathway-output space (see CLIPAlignmentLoss docstring). 
Available + # for opt-in experiments; not part of the current default track. + return CompositeLoss( + alpha=args.pcc_weight, + pcc_type="ccc", + clip_weight=clip_w or 0.5, + clip_temperature=clip_t, + ) + elif args.loss == "mse_huber": + return CompositeLoss(alpha=args.pcc_weight, mse_type="huber", pcc_type="ccc") else: return MaskedMSELoss() diff --git a/src/spatial_transcript_former/training/checkpoint.py b/src/spatial_transcript_former/training/checkpoint.py index e5ff772..1934af3 100644 --- a/src/spatial_transcript_former/training/checkpoint.py +++ b/src/spatial_transcript_former/training/checkpoint.py @@ -1,16 +1,27 @@ import os import torch +from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + PATHWAY_FILE_VERSION, +) + def save_checkpoint( - model, optimizer, scaler, schedulers, epoch, best_val_loss, output_dir, model_name + model, optimizer, scaler, schedulers, epoch, best_val_metric, output_dir, model_name ): - """Save training state for resuming.""" + """Save training state for resuming. + + ``best_val_metric`` semantics: highest val_ccc seen so far (higher is + better). Stored as ``best_val_metric`` and mirrored into the legacy + ``best_val_loss`` key so older tools that read the field still parse it. + """ save_dict = { "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), - "best_val_loss": best_val_loss, + "best_val_metric": best_val_metric, + "best_val_loss": best_val_metric, # legacy key, same value + "pathway_format_version": PATHWAY_FILE_VERSION, } if scaler is not None: save_dict["scaler_state_dict"] = scaler.state_dict() @@ -29,12 +40,12 @@ def load_checkpoint( Load checkpoint if it exists. Returns: - tuple: (start_epoch, best_val_loss, loaded_schedulers) + tuple: (start_epoch, best_val_metric, loaded_schedulers) """ ckpt_path = os.path.join(output_dir, f"latest_model_{model_name}.pth") if not os.path.exists(ckpt_path): print(f"No checkpoint found at {ckpt_path}. 
Starting from scratch.") - return 0, float("inf"), False + return 0, -float("inf"), False print(f"Resuming from {ckpt_path}...") try: @@ -43,7 +54,7 @@ def load_checkpoint( print( f"Failed to load checkpoint at {ckpt_path} due to error: {e}. Starting from scratch." ) - return 0, float("inf"), False + return 0, -float("inf"), False incompatible_keys = model.load_state_dict( checkpoint["model_state_dict"], strict=False @@ -67,10 +78,27 @@ def load_checkpoint( print( f"Failed to load optimizer/scheduler states due to architecture change ({e}). Starting from scratch." ) - return 0, float("inf"), False + return 0, -float("inf"), False + + # Refuse to resume if the targets the checkpoint was trained against + # don't match the current preprocessing format. + ckpt_pfv = checkpoint.get("pathway_format_version", None) + if ckpt_pfv is not None and int(ckpt_pfv) != PATHWAY_FILE_VERSION: + print( + f"Checkpoint pathway_format_version={ckpt_pfv} does not match " + f"current {PATHWAY_FILE_VERSION}. Refusing to resume — " + "regenerate targets and start fresh." + ) + return 0, -float("inf"), False start_epoch = checkpoint.get("epoch", -1) + 1 - best_val_loss = checkpoint.get("best_val_loss", float("inf")) + best_val_metric = checkpoint.get( + "best_val_metric", checkpoint.get("best_val_loss", -float("inf")) + ) + # Old checkpoints stored loss (lower is better) under best_val_loss; once + # that lands in our higher-is-better tracker it's unusable. Reset it. 
+ if "best_val_metric" not in checkpoint: + best_val_metric = -float("inf") print(f"Resumed at epoch {start_epoch + 1}") - return start_epoch, best_val_loss, loaded_schedulers + return start_epoch, best_val_metric, loaded_schedulers diff --git a/src/spatial_transcript_former/training/engine.py b/src/spatial_transcript_former/training/engine.py index 8fa7a3b..c2221f3 100644 --- a/src/spatial_transcript_former/training/engine.py +++ b/src/spatial_transcript_former/training/engine.py @@ -12,33 +12,10 @@ from tqdm import tqdm from spatial_transcript_former.models import SpatialTranscriptFormer from spatial_transcript_former.data.spatial_stats import spatial_coherence_score -from spatial_transcript_former.training.losses import CompositeLoss, MaskedMSELoss -def _prepare_pathway_weights(pathway_morans, device): - """Average per-sample Moran's I weights into a single (P,) vector. - - Args: - pathway_morans: (B, P) tensor of per-sample Moran's I weights, or None. - device: Target device. - - Returns: - (P,) tensor or None. 
- """ - if pathway_morans is None: - return None - pw = pathway_morans.to(device) - if pw.dim() == 2: - pw = pw.mean(dim=0) # Average across batch samples - return pw - - -def _criterion_call(criterion, preds, targets, mask=None, pathway_weights=None): - """Call the criterion, passing pathway_weights only if supported.""" - if pathway_weights is not None and isinstance( - criterion, (CompositeLoss, MaskedMSELoss) - ): - return criterion(preds, targets, mask=mask, pathway_weights=pathway_weights) +def _criterion_call(criterion, preds, targets, mask=None): + """Call the criterion, passing the mask only when provided.""" if mask is not None: return criterion(preds, targets, mask=mask) return criterion(preds, targets) @@ -98,8 +75,8 @@ def train_one_epoch( if whole_slide: for batch_idx, batch in enumerate(pbar): - # Unpack: (feats, None, pathway_targets, coords, mask, pathway_morans) - feats, _, pathway_targets, coords, mask, pathway_morans = batch + # Unpack: (feats, None, pathway_targets, coords, mask) + feats, _, pathway_targets, coords, mask = batch feats = feats.to(device, non_blocking=True) coords = coords.to(device, non_blocking=True) mask = mask.to(device, non_blocking=True) @@ -108,7 +85,6 @@ def train_one_epoch( "pathway_targets is None, but training now requires pathway targets." 
) pathway_targets = pathway_targets.to(device, non_blocking=True) - pw = _prepare_pathway_weights(pathway_morans, device) with torch.amp.autocast("cuda", enabled=scaler is not None): if isinstance(model, SpatialTranscriptFormer) and not getattr( @@ -125,7 +101,6 @@ def train_one_epoch( preds, pathway_targets, mask=mask, - pathway_weights=pw, ) else: preds = model(feats) @@ -134,7 +109,6 @@ def train_one_epoch( criterion, preds, bag_target, - pathway_weights=pw, ) loss = loss / grad_accum_steps @@ -148,8 +122,8 @@ def train_one_epoch( pbar.set_postfix({"loss": f"{current_loss:.4f}"}) else: for batch_idx, batch in enumerate(pbar): - # Unpack: (images, None, pathway_targets, rel_coords, pathway_morans) - images, _, pathway_targets, rel_coords, pathway_morans = batch + # Unpack: (images, None, pathway_targets, rel_coords) + images, _, pathway_targets, rel_coords = batch images = images.to(device, non_blocking=True) rel_coords = rel_coords.to(device, non_blocking=True) if pathway_targets is None: @@ -157,7 +131,6 @@ def train_one_epoch( "pathway_targets is None, but training now requires pathway targets." ) pathway_targets = pathway_targets.to(device, non_blocking=True) - pw = _prepare_pathway_weights(pathway_morans, device) with torch.amp.autocast("cuda", enabled=scaler is not None): if isinstance(model, SpatialTranscriptFormer): @@ -169,7 +142,6 @@ def train_one_epoch( criterion, outputs, pathway_targets, - pathway_weights=pw, ) loss = loss / grad_accum_steps @@ -188,13 +160,24 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) """ Validate the model. 
- Returns: - dict: {"val_loss": float, "attn_correlation": float or None} + Returns a metrics dict including ``val_loss``, ``val_mae``, ``val_pcc``, + ``val_ccc`` (averaged across pathways and slides), ``val_baseline_mae`` + (the MAE of a constant-mean-of-targets predictor — a zero-information + skill baseline), ``val_pcc_per_pathway`` / ``val_ccc_per_pathway`` + (per-pathway lists of slide-averaged values, ``None`` for pathways with + no spatial variance), plus ``pred_variance``, ``spatial_coherence`` and + ``attn_correlation``. """ model.eval() running_loss = 0.0 running_mae = 0.0 + running_baseline_mae = 0.0 + n_baseline_batches = 0 pcc_list = [] + ccc_list = [] + # Per-pathway accumulators: index -> list of slide-level values + per_pathway_pcc: dict = {} + per_pathway_ccc: dict = {} pred_var_list = [] attn_correlations = [] spatial_coherence_list = [] @@ -204,24 +187,22 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) loader, desc="Validation", file=sys.stdout, dynamic_ncols=True ): if whole_slide: - # Unpack: (feats, None, pathway_targets, coords, mask, pathway_morans) - feats, _, pathway_targets, coords, mask, pathway_morans = batch + # Unpack: (feats, None, pathway_targets, coords, mask) + feats, _, pathway_targets, coords, mask = batch feats = feats.to(device, non_blocking=True) coords = coords.to(device, non_blocking=True) mask = mask.to(device, non_blocking=True) if pathway_targets is None: raise ValueError("pathway_targets is None in validation.") pathway_targets = pathway_targets.to(device, non_blocking=True) - pw = _prepare_pathway_weights(pathway_morans, device) else: - # Unpack: (images, None, pathway_targets, rel_coords, pathway_morans) - images, _, pathway_targets, rel_coords, pathway_morans = batch + # Unpack: (images, None, pathway_targets, rel_coords) + images, _, pathway_targets, rel_coords = batch images = images.to(device, non_blocking=True) rel_coords = rel_coords.to(device, non_blocking=True) if 
pathway_targets is None: raise ValueError("pathway_targets is None in validation.") pathway_targets = pathway_targets.to(device, non_blocking=True) - pw = _prepare_pathway_weights(pathway_morans, device) with torch.amp.autocast("cuda", enabled=use_amp): attn = None @@ -265,17 +246,15 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) outputs, targets, mask=mask, - pathway_weights=pw, ) else: loss = _criterion_call( criterion, outputs, targets, - pathway_weights=pw, ) - # --- Interpretability Metrics (MAE & PCC) --- + # --- Interpretability Metrics (MAE, PCC, CCC) --- eval_preds = outputs mae_diff = torch.abs(eval_preds - targets) if ( @@ -286,59 +265,111 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) ): valid_mask = ~mask.unsqueeze(-1).expand_as(mae_diff) mae_val = (mae_diff * valid_mask.float()).sum() / valid_mask.sum() + + # Skill baseline: predict the per-slide per-pathway mean + # constant. MAE of that predictor is the natural reference + # for "zero information" performance in the same units. 
+ valid_counts = ( + (~mask).float().sum(dim=1, keepdim=True).clamp(min=1.0) + ) + target_means = (targets * (~mask).float().unsqueeze(-1)).sum( + dim=1, keepdim=True + ) / valid_counts.unsqueeze(-1) + baseline_diff = torch.abs(targets - target_means) + baseline_mae_val = ( + baseline_diff * valid_mask.float() + ).sum() / valid_mask.sum() else: mae_val = mae_diff.mean() + target_means = targets.mean(dim=0, keepdim=True) + baseline_mae_val = torch.abs(targets - target_means).mean() + + # PCC and CCC — computed for all model/mode combinations + if torch.isfinite(eval_preds).all() and torch.isfinite(targets).all(): + if whole_slide: + for b_idx in range(eval_preds.shape[0]): + p_slide = eval_preds[b_idx] # (N, G) + t_slide = targets[b_idx] # (N, G) + valid_idx = ~mask[b_idx] + p_slide = p_slide[valid_idx] # (V, G) + t_slide = t_slide[valid_idx] # (V, G) + + if p_slide.shape[0] >= 2: + pred_mean = p_slide.mean(dim=0) # (G,) + tgt_mean = t_slide.mean(dim=0) # (G,) + vx = p_slide - pred_mean + vy = t_slide - tgt_mean + cov = (vx * vy).sum(dim=0) # (G,) + var_x = (vx**2).sum(dim=0) # (G,) + var_y = (vy**2).sum(dim=0) # (G,) + + corr = cov / ( + torch.sqrt(var_x + 1e-8) * torch.sqrt(var_y + 1e-8) + ) + N_v = p_slide.shape[0] + mean_diff_sq = (pred_mean - tgt_mean) ** 2 + ccc_vals = (2 * cov) / ( + var_x + var_y + N_v * mean_diff_sq + 1e-8 + ) - if ( - torch.isfinite(eval_preds).all() - and torch.isfinite(targets).all() - ): - # Calculate Spatial PCC (across spots N, for each gene G independently) - # outputs/targets are (B, N, G) for whole_slide or (B, G) for patch - if whole_slide: - # Iterate over batches to correlate spatially for each slide - B = eval_preds.shape[0] - for b_idx in range(B): - p_slide = eval_preds[b_idx] # (N, G) - t_slide = targets[b_idx] # (N, G) - - valid_idx = ~mask[b_idx] - p_slide = p_slide[valid_idx] # (V, G) - t_slide = t_slide[valid_idx] # (V, G) - - if p_slide.shape[0] >= 2: - vx = p_slide - p_slide.mean(dim=0, keepdim=True) - vy = t_slide - 
t_slide.mean(dim=0, keepdim=True) - num = torch.sum(vx * vy, dim=0) # (G,) - den = torch.sqrt( - torch.sum(vx**2, dim=0) + 1e-8 - ) * torch.sqrt(torch.sum(vy**2, dim=0) + 1e-8) - corr = num / den - - active_genes = torch.std(t_slide, dim=0) > 1e-6 - if active_genes.any(): - valid_corrs = corr[active_genes] - valid_corrs = valid_corrs[ - torch.isfinite(valid_corrs) - ] - if len(valid_corrs) > 0: - pcc_list.append(valid_corrs.mean().item()) - else: - # Patch level (B, G). Correlate across the batch B (which is spatial patches) - vx = eval_preds - eval_preds.mean(dim=0, keepdim=True) - vy = targets - targets.mean(dim=0, keepdim=True) - num = torch.sum(vx * vy, dim=0) - den = torch.sqrt( - torch.sum(vx**2, dim=0) + 1e-8 - ) * torch.sqrt(torch.sum(vy**2, dim=0) + 1e-8) - corr = num / den - - active_genes = torch.std(targets, dim=0) > 1e-6 - if active_genes.any(): - valid_corrs = corr[active_genes] - valid_corrs = valid_corrs[torch.isfinite(valid_corrs)] - if len(valid_corrs) > 0: - pcc_list.append(valid_corrs.mean().item()) + valid_genes = torch.std(t_slide, dim=0) > 1e-6 + if valid_genes.any(): + vc = corr[valid_genes] + vc = vc[torch.isfinite(vc)] + if len(vc) > 0: + pcc_list.append(vc.mean().item()) + vk = ccc_vals[valid_genes] + vk = vk[torch.isfinite(vk)] + if len(vk) > 0: + ccc_list.append(vk.mean().item()) + + # Per-pathway accumulation (slide-level entries) + valid_idx_genes = torch.where(valid_genes)[0].tolist() + for g in valid_idx_genes: + c = corr[g].item() + k = ccc_vals[g].item() + if np.isfinite(c): + per_pathway_pcc.setdefault(g, []).append(c) + if np.isfinite(k): + per_pathway_ccc.setdefault(g, []).append(k) + else: + pred_mean = eval_preds.mean(dim=0) # (G,) + tgt_mean = targets.mean(dim=0) # (G,) + vx = eval_preds - pred_mean + vy = targets - tgt_mean + cov = (vx * vy).sum(dim=0) + var_x = (vx**2).sum(dim=0) + var_y = (vy**2).sum(dim=0) + + corr = cov / ( + torch.sqrt(var_x + 1e-8) * torch.sqrt(var_y + 1e-8) + ) + B_size = eval_preds.shape[0] + 
mean_diff_sq = (pred_mean - tgt_mean) ** 2 + ccc_vals = (2 * cov) / ( + var_x + var_y + B_size * mean_diff_sq + 1e-8 + ) + + valid_genes = torch.std(targets, dim=0) > 1e-6 + if valid_genes.any(): + vc = corr[valid_genes] + vc = vc[torch.isfinite(vc)] + if len(vc) > 0: + pcc_list.append(vc.mean().item()) + vk = ccc_vals[valid_genes] + vk = vk[torch.isfinite(vk)] + if len(vk) > 0: + ccc_list.append(vk.mean().item()) + + # Per-pathway accumulation (one entry per validation pass) + valid_idx_genes = torch.where(valid_genes)[0].tolist() + for g in valid_idx_genes: + c = corr[g].item() + k = ccc_vals[g].item() + if np.isfinite(c): + per_pathway_pcc.setdefault(g, []).append(c) + if np.isfinite(k): + per_pathway_ccc.setdefault(g, []).append(k) # Spatial Attention Correlation (MIL weak supervision study) if attn is not None and whole_slide: @@ -354,6 +385,8 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) running_loss += loss.item() running_mae += mae_val.item() + running_baseline_mae += baseline_mae_val.item() + n_baseline_batches += 1 # Track prediction variance (collapse detector) with torch.no_grad(): @@ -387,7 +420,11 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) avg_loss = running_loss / len(loader) avg_mae = running_mae / len(loader) + avg_baseline_mae = ( + running_baseline_mae / n_baseline_batches if n_baseline_batches else None + ) avg_pcc = sum(pcc_list) / len(pcc_list) if pcc_list else None + avg_ccc = sum(ccc_list) / len(ccc_list) if ccc_list else None avg_corr = ( sum(attn_correlations) / len(attn_correlations) if attn_correlations else None ) @@ -399,8 +436,18 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) else None ) + # Collapse per-pathway accumulators to a single number per pathway. 
+ pcc_per_pathway = {g: float(np.mean(v)) for g, v in per_pathway_pcc.items()} + ccc_per_pathway = {g: float(np.mean(v)) for g, v in per_pathway_ccc.items()} + + corr_line = f"Validation MAE: {avg_mae:.4f}" + if avg_baseline_mae is not None: + corr_line += f" (baseline {avg_baseline_mae:.4f})" if avg_pcc is not None: - print(f"Validation MAE: {avg_mae:.4f} | PCC: {avg_pcc:.4f}") + corr_line += f" | PCC: {avg_pcc:.4f}" + if avg_ccc is not None: + corr_line += f" | CCC: {avg_ccc:.4f}" + print(corr_line) if avg_pred_var is not None: print(f"Prediction Variance: {avg_pred_var:.6f}") if avg_spatial_coherence is not None: @@ -408,10 +455,23 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) if avg_corr is not None: print(f"Spatial Attention Correlation: {avg_corr:.4f}") + # Per-pathway breakdown — top-3 best/worst by CCC, when available. + if ccc_per_pathway: + ranked = sorted(ccc_per_pathway.items(), key=lambda kv: kv[1]) + worst = ranked[:3] + best = ranked[-3:][::-1] + fmt = lambda items: ", ".join(f"p{g}={v:.3f}" for g, v in items) + print(f" best CCC pathways: {fmt(best)}") + print(f" worst CCC pathways: {fmt(worst)}") + return { "val_loss": avg_loss, "val_mae": avg_mae, + "val_baseline_mae": avg_baseline_mae, "val_pcc": avg_pcc, + "val_ccc": avg_ccc, + "val_pcc_per_pathway": pcc_per_pathway, + "val_ccc_per_pathway": ccc_per_pathway, "pred_variance": avg_pred_var, "spatial_coherence": avg_spatial_coherence, "attn_correlation": avg_corr, diff --git a/src/spatial_transcript_former/training/losses.py b/src/spatial_transcript_former/training/losses.py index 94690f8..d95e4ae 100644 --- a/src/spatial_transcript_former/training/losses.py +++ b/src/spatial_transcript_former/training/losses.py @@ -1,13 +1,14 @@ """ Loss functions for SpatialTranscriptFormer. -Supports MSE, PCC, and composite MSE+PCC objectives. +Supports MSE, PCC, CCC, Huber, CLIP-alignment, and composite objectives. 
All losses handle both patch-level (B, G) and dense (B, N, G) inputs, with optional masking for padded positions in whole-slide mode. """ import torch import torch.nn as nn +import torch.nn.functional as F class PCCLoss(nn.Module): @@ -103,23 +104,98 @@ def forward(self, preds, target, mask=None): return 1 - cost.mean() +class CCCLoss(PCCLoss): + """ + Concordance Correlation Coefficient Loss. + + CCC = 2 * cov(pred, target) / (var(pred) + var(target) + (mean(pred) - mean(target))^2) + Loss = 1 - mean(CCC) + + Strictly more informative than PCC: a prediction that is correctly correlated + but systematically offset (wrong mean or variance) will have CCC < PCC. + """ + + def forward(self, preds, target, mask=None): + """ + Args: + preds: (B, G) or (B, N, G) + target: (B, G) or (B, N, G) + mask: (B, N) boolean, True = padded (ignore). Optional. + + Returns: + Scalar loss = 1 - mean(CCC). + """ + if preds.dim() == 2: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + if mask is not None: + mask = mask.unsqueeze(1) + + B, N, G = preds.shape + + if N == 1: + preds = preds.squeeze(1) + target = target.squeeze(1) + if mask is not None: + valid = ~mask.squeeze(1) + preds = preds[valid] + target = target[valid] + + if preds.shape[0] < 2: + return torch.tensor(0.0, device=preds.device, requires_grad=True) + + pred_mean = preds.mean(dim=0) + target_mean = target.mean(dim=0) + vx = preds - pred_mean + vy = target - target_mean + cov = (vx * vy).sum(dim=0) + var_x = (vx**2).sum(dim=0) + var_y = (vy**2).sum(dim=0) + mean_diff_sq = (pred_mean - target_mean) ** 2 + ccc = (2 * cov) / (var_x + var_y + mean_diff_sq + self.eps) + return 1 - ccc.mean() + + if mask is not None: + valid = ~mask.unsqueeze(-1) + preds = preds * valid.float() + target = target * valid.float() + valid_counts = valid.sum(dim=1, keepdim=True).clamp(min=1.0) + else: + valid_counts = torch.tensor(N, dtype=torch.float32, device=preds.device) + + pred_means = preds.sum(dim=1, keepdim=True) / valid_counts 
+ target_means = target.sum(dim=1, keepdim=True) / valid_counts + + vx = preds - pred_means + vy = target - target_means + + if mask is not None: + vx = vx * valid.float() + vy = vy * valid.float() + + cov = (vx * vy).sum(dim=1) # (B, G) + var_x = (vx**2).sum(dim=1) # (B, G) + var_y = (vy**2).sum(dim=1) # (B, G) + + mean_diff_sq = (pred_means.squeeze(1) - target_means.squeeze(1)) ** 2 # (B, G) + ccc = (2 * cov) / (var_x + var_y + mean_diff_sq + self.eps) # (B, G) + + return 1 - ccc.mean() + + class MaskedMSELoss(nn.Module): """ - MSE loss with optional masking for padded positions and - optional per-pathway (per-gene) weighting. + MSE loss with optional masking for padded positions. When no mask is provided, behaves identically to nn.MSELoss(). - When pathway_weights is provided, each output dimension's MSE - is scaled by its weight before averaging. """ - def forward(self, preds, target, mask=None, pathway_weights=None): + def forward(self, preds, target, mask=None): """ Args: preds: (B, G) or (B, N, G) target: (B, G) or (B, N, G) mask: (B, N) boolean, True = padded (ignore). Optional. - pathway_weights: (G,) float tensor of per-pathway weights. Optional. Returns: Scalar MSE loss over valid positions. 
@@ -127,60 +203,154 @@ def forward(self, preds, target, mask=None, pathway_weights=None): diff = (preds - target) ** 2 if mask is not None and preds.dim() == 3: - # Expand mask to gene dimension: (B, N) -> (B, N, G) valid = ~mask.unsqueeze(-1).expand_as(diff) - if pathway_weights is not None: - # weights: (G,) -> (1, 1, G) - w = pathway_weights.unsqueeze(0).unsqueeze(0) - weighted = diff * valid.float() * w - return weighted.sum() / (valid.float() * w).sum() return (diff * valid.float()).sum() / valid.sum() - if pathway_weights is not None: - # preds is (B, G): weight each pathway dimension - # weights: (G,) -> (1, G) - w = pathway_weights.unsqueeze(0) - return (diff * w).mean(dim=0).sum() / w.sum() + return diff.mean() + + +class MaskedHuberLoss(nn.Module): + """ + Huber (smooth L1) loss with optional masking for padded positions. + + More robust to outlier pathway activity values than MSE: quadratic near + zero, linear beyond delta. + + Args: + delta: Threshold between quadratic and linear regimes. Default 1.0. + """ + + def __init__(self, delta=1.0): + super().__init__() + self.delta = delta + + def forward(self, preds, target, mask=None): + """ + Args: + preds: (B, G) or (B, N, G) + target: (B, G) or (B, N, G) + mask: (B, N) boolean, True = padded (ignore). Optional. + + Returns: + Scalar Huber loss over valid positions. + """ + diff = F.huber_loss(preds, target, reduction="none", delta=self.delta) + + if mask is not None and preds.dim() == 3: + valid = ~mask.unsqueeze(-1).expand_as(diff) + return (diff * valid.float()).sum() / valid.sum() return diff.mean() +class CLIPAlignmentLoss(nn.Module): + """ + Batch-discriminative regulariser over predicted vs. target pathway vectors. + + Despite the name, this is **not** cross-modal CLIP. Both inputs already + live in pathway space: the model's predicted pathway vector and the + ground-truth pathway vector. 
We borrow CLIP's loss form — symmetric + cross-entropy on a B×B cosine-similarity matrix with diagonal labels — + to enforce "your prediction must rank its own target highest in the + batch." Acts as anti-collapse pressure on regression. + + Status (2026-04): kept available behind ``CompositeLoss(clip_weight=...)`` + and the ``mse_ccc_clip`` builder option, but **not part of the current + experimentation track**. Direct-regression baselines (``mse_ccc``) are + the focus while target construction and validation metrics stabilise. + + Deferred follow-up — true cross-modal CLIP. A separate variant would + align the model's image embedding (e.g. mean-pooled transformer output) + with a learned pathway-target encoder in a shared space, giving + retrieval / zero-shot capability. Not implemented; left as a future + direction once the regression baseline is solid. + + Args: + temperature: Softmax temperature τ. Default 0.07 (CLIP default). + """ + + def __init__(self, temperature=0.07): + super().__init__() + self.temperature = temperature + + def forward(self, preds, target, mask=None, **kwargs): + """ + Args: + preds: (B, G) or (B, N, G) + target: (B, G) or (B, N, G) + mask: (B, N) boolean, True = padded. Used to average dense inputs. + + Returns: + Scalar symmetric cross-entropy loss. 
+ """ + if preds.dim() == 3: + if mask is not None: + valid = (~mask).unsqueeze(-1).float() + preds = (preds * valid).sum(1) / valid.sum(1).clamp(min=1) + target = (target * valid).sum(1) / valid.sum(1).clamp(min=1) + else: + preds = preds.mean(1) + target = target.mean(1) + + B = preds.shape[0] + if B < 2: + return torch.tensor(0.0, device=preds.device, requires_grad=True) + + norm_p = F.normalize(preds, dim=-1) # (B, G) + norm_t = F.normalize(target, dim=-1) # (B, G) + + sim = (norm_p @ norm_t.T) / self.temperature # (B, B) + labels = torch.arange(B, device=sim.device) + + return 0.5 * (F.cross_entropy(sim, labels) + F.cross_entropy(sim.T, labels)) + + class CompositeLoss(nn.Module): """ - Combined MSE + PCC loss for spatial gene expression prediction. + Combined MSE + PCC/CCC loss, with an optional CLIP alignment term. - L = MSE + alpha * (1 - PCC) + L = MSE + alpha * (1 - PCC/CCC) [+ clip_weight * L_CLIP] - MSE focuses on magnitude accuracy; PCC ensures spatial pattern - coherence across all genes equally (scale-invariant). Together - they address the gene expression imbalance problem. + MSE focuses on magnitude accuracy; PCC/CCC ensures spatial pattern + coherence; CLIP prevents variance collapse across the batch. Args: - alpha: Weight for the PCC term. Default 1.0. - eps: Numerical stability for PCC. Default 1e-8. + alpha: Weight for the PCC/CCC term. Default 1.0. + eps: Numerical stability for correlation losses. Default 1e-8. + mse_type: "mse" or "huber". Default "mse". + pcc_type: "pcc" or "ccc". Default "pcc". + clip_weight: Weight for CLIP alignment term. 0.0 disables it. + clip_temperature: Temperature τ for CLIP loss. Default 0.07. 
""" - def __init__(self, alpha=1.0, eps=1e-8, mse_type="mse"): + def __init__( + self, + alpha=1.0, + eps=1e-8, + mse_type="mse", + pcc_type="pcc", + clip_weight=0.0, + clip_temperature=0.07, + ): super().__init__() self.alpha = alpha - if mse_type == "huber": - self.mse = MaskedHuberLoss() - else: - self.mse = MaskedMSELoss() - self.pcc = PCCLoss(eps=eps) + self.mse = MaskedHuberLoss() if mse_type == "huber" else MaskedMSELoss() + self.pcc = CCCLoss(eps=eps) if pcc_type == "ccc" else PCCLoss(eps=eps) + self.clip = CLIPAlignmentLoss(clip_temperature) if clip_weight > 0 else None + self.clip_weight = clip_weight - def forward(self, preds, target, mask=None, pathway_weights=None): + def forward(self, preds, target, mask=None): """ Args: preds: (B, G) or (B, N, G) target: (B, G) or (B, N, G) mask: (B, N) boolean, True = padded (ignore). Optional. - pathway_weights: (G,) float tensor of per-pathway weights. Optional. - Applied to the MSE term only; PCC is left unweighted. Returns: Scalar composite loss. 
""" - mse_val = self.mse(preds, target, mask, pathway_weights=pathway_weights) - pcc_val = self.pcc(preds, target, mask) - return mse_val + self.alpha * pcc_val + loss = self.mse(preds, target, mask) + loss = loss + self.alpha * self.pcc(preds, target, mask) + if self.clip is not None: + loss = loss + self.clip_weight * self.clip(preds, target, mask) + return loss diff --git a/src/spatial_transcript_former/training/trainer.py b/src/spatial_transcript_former/training/trainer.py index 178d673..9a3bc70 100644 --- a/src/spatial_transcript_former/training/trainer.py +++ b/src/spatial_transcript_former/training/trainer.py @@ -22,6 +22,7 @@ trainer.save_pretrained("./release/v1/", gene_names=my_genes) """ +import math import os import time from typing import Any, Callable, Dict, List, Optional @@ -215,25 +216,19 @@ def __init__( # ------------------------------------------------------------------ def _build_scheduler(self): - warmup = optim.lr_scheduler.LinearLR( - self.optimizer, - start_factor=0.01, - total_iters=max(1, self.warmup_epochs), - ) - cosine = optim.lr_scheduler.CosineAnnealingLR( - self.optimizer, - T_max=max(1, self.epochs - self.warmup_epochs), - eta_min=1e-6, - ) - - if self.warmup_epochs > 0: - self.scheduler = optim.lr_scheduler.SequentialLR( - self.optimizer, - schedulers=[warmup, cosine], - milestones=[self.warmup_epochs], - ) - else: - self.scheduler = cosine + base_lr = self.optimizer.param_groups[0]["lr"] + eta_min = 1e-6 + warmup_epochs = self.warmup_epochs + total_epochs = self.epochs + + def lr_lambda(epoch): + if epoch < warmup_epochs: + return 0.01 + 0.99 * epoch / max(1, warmup_epochs) + progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs) + cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + return (eta_min / base_lr) + (1.0 - eta_min / base_lr) * cosine + + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) def _resume_from_checkpoint(self): schedulers = {"main": self.scheduler} diff 
--git a/src/spatial_transcript_former/visualization.py b/src/spatial_transcript_former/visualization.py index 7d52f3b..955727c 100644 --- a/src/spatial_transcript_former/visualization.py +++ b/src/spatial_transcript_former/visualization.py @@ -39,26 +39,39 @@ def _load_histology(h5ad_path): def _get_pathway_names(args, num_expected: int): - """Get pathway names to match the expected number of pathways.""" - try: - from spatial_transcript_former.data.pathways import ( - get_pathway_init, - MSIGDB_URLS, - ) - from spatial_transcript_former.data import GeneVocab - - vocab = GeneVocab.from_json(args.data_dir, num_genes=args.num_genes) - urls = [MSIGDB_URLS["hallmarks"]] - _, pw_names = get_pathway_init( - vocab.genes, - gmt_urls=urls, - verbose=False, - filter_names=getattr(args, "pathways", None), - ) - if len(pw_names) == num_expected: - return pw_names - except Exception: - pass + """Get pathway names to match the expected number of pathways. + + Resolution order: + 1. args.pathways (explicit subset, e.g. CRC_PATHWAYS) — exact, fast. + 2. First .h5 file found in args.pathway_targets_dir — reads the + pathway_names dataset written by stf-compute-pathways. + 3. Generic fallback: ["Pathway_0", ...]. + """ + # 1. Explicit pathway list (e.g. CRC run configured with --pathways) + explicit = getattr(args, "pathways", None) + if explicit and len(explicit) == num_expected: + return list(explicit) + + # 2. Load names from a pathway activity .h5 file + targets_dir = getattr(args, "pathway_targets_dir", None) + if targets_dir and os.path.isdir(targets_dir): + for fname in os.listdir(targets_dir): + if not fname.endswith(".h5"): + continue + try: + with h5py.File(os.path.join(targets_dir, fname), "r") as f: + if "pathway_names" in f: + names = [ + n.decode() if isinstance(n, bytes) else n + for n in f["pathway_names"][:] + ] + if len(names) == num_expected: + return names + except Exception: + pass + break # Only inspect the first .h5 found + + # 3. 
Generic fallback return [f"Pathway_{i}" for i in range(num_expected)] @@ -83,13 +96,13 @@ def run_inference_plot(model, args, sample_id, epoch, device): with torch.no_grad(): for batch in val_loader: if args.whole_slide: - image_features, _, target, coords, mask, _ = batch + image_features, _, target, coords, mask = batch image_features = image_features.to(device) coords = coords.to(device) mask = mask.to(device) target = target.to(device) else: - image_features, _, target, coords, _ = batch + image_features, _, target, coords = batch image_features = image_features.to(device) coords = coords.to(device) mask = torch.ones(target.shape[0], target.shape[1], device=device) @@ -126,11 +139,6 @@ def run_inference_plot(model, args, sample_id, epoch, device): coords = all_coords.numpy()[0] mask = all_masks.numpy()[0] - # Un-log if necessary to get absolute counts - if getattr(args, "log_transform", False): - pathway_preds = np.expm1(pathway_preds) - pathway_truth = np.expm1(pathway_truth) - # 3. Filter Valid Spots if args.whole_slide: valid_idx = ~mask.astype(bool) @@ -145,9 +153,7 @@ def run_inference_plot(model, args, sample_id, epoch, device): # image. Let's fetch the raw coordinates directly from the .pt file. 
try: from spatial_transcript_former.data.paths import resolve_feature_dir - from spatial_transcript_former.recipes.hest.dataset import ( - load_gene_expression_matrix, - ) + from spatial_transcript_former.recipes.hest.dataset import get_h5ad_valid_mask feat_dir = resolve_feature_dir( args.data_dir, @@ -163,13 +169,7 @@ def run_inference_plot(model, args, sample_id, epoch, device): h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") # We need the same valid mask used by the dataset - _, pt_mask, _ = load_gene_expression_matrix( - h5ad_path, - barcodes, - selected_gene_names=None, - num_genes=1, - ) - pt_mask_bool = np.array(pt_mask, dtype=bool) + pt_mask_bool = get_h5ad_valid_mask(h5ad_path, barcodes) if len(raw_coords[pt_mask_bool]) == len(coords): coords = raw_coords[pt_mask_bool] diff --git a/tests/data/test_pathways.py b/tests/data/test_pathways.py index 00055f3..7795675 100644 --- a/tests/data/test_pathways.py +++ b/tests/data/test_pathways.py @@ -207,6 +207,7 @@ def test_h5_roundtrip_morans(self, tmp_path): import h5py from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( load_pathway_activities, + PATHWAY_FILE_VERSION, ) n_spots, n_pathways = 50, 5 @@ -222,16 +223,18 @@ def test_h5_roundtrip_morans(self, tmp_path): f.create_dataset("barcodes", data=barcodes_bytes) f.create_dataset("pathway_names", data=pw_names) f.create_dataset("pathway_morans_i", data=morans_orig) + f.attrs["format_version"] = PATHWAY_FILE_VERSION _, _, _, morans_loaded = load_pathway_activities(h5_path, barcodes_raw) assert morans_loaded is not None np.testing.assert_array_almost_equal(morans_loaded, morans_orig) def test_h5_missing_morans_returns_none(self, tmp_path): - """Older H5 files without pathway_morans_i should return None.""" + """A current-version H5 without pathway_morans_i should return None.""" import h5py from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( load_pathway_activities, + PATHWAY_FILE_VERSION, ) n_spots, 
n_pathways = 20, 3 @@ -245,7 +248,129 @@ def test_h5_missing_morans_returns_none(self, tmp_path): f.create_dataset("activities", data=acts) f.create_dataset("barcodes", data=barcodes_bytes) f.create_dataset("pathway_names", data=pw_names) + f.attrs["format_version"] = PATHWAY_FILE_VERSION # No pathway_morans_i dataset _, _, _, morans_loaded = load_pathway_activities(h5_path, barcodes_raw) assert morans_loaded is None + + +# --------------------------------------------------------------------------- +# Score construction & file format +# --------------------------------------------------------------------------- + + +class TestPathwayScoring: + """Tests for `_score_pathways` semantics and on-disk file versioning.""" + + def test_score_is_simple_member_mean(self): + """activities[s, p] should equal mean of log1p CP10k expression of pathway p + members, computed on the input matrix as-is — no per-gene rescaling.""" + from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + _score_pathways, + ) + + # 3 spots, 4 genes + expr = np.array( + [ + [1.0, 2.0, 3.0, 0.0], + [4.0, 0.0, 1.0, 2.0], + [0.0, 5.0, 2.0, 1.0], + ], + dtype=np.float32, + ) + gene_names = ["A", "B", "C", "D"] + # Two pathways. PW1 has 3 members so is scored; PW2 has 1 so is skipped. 
+ pathway_dict = { + "PW1": ["A", "B", "C"], + "PW2": ["D"], + } + + activities, names, n_scored = _score_pathways( + expr, gene_names, pathway_dict, min_genes=3 + ) + + assert activities.shape == (3, 2) + assert names == ["PW1", "PW2"] + assert n_scored == 1 + + expected_pw1 = expr[:, [0, 1, 2]].mean(axis=1) + np.testing.assert_array_almost_equal(activities[:, 0], expected_pw1) + # PW2 below min_genes — left at zero + np.testing.assert_array_almost_equal(activities[:, 1], np.zeros(3)) + + def test_score_is_slide_stationary(self): + """Adding a per-slide constant to expression shifts the score by the same + constant — so a scoring-only z-score (which would centre this away) is + not happening.""" + from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + _score_pathways, + ) + + np.random.seed(0) + expr = np.random.rand(8, 5).astype(np.float32) + gene_names = [f"G{i}" for i in range(5)] + pathway_dict = {"PW": ["G0", "G1", "G2"]} + + acts_a, _, _ = _score_pathways(expr, gene_names, pathway_dict, min_genes=3) + acts_b, _, _ = _score_pathways( + expr + 5.0, gene_names, pathway_dict, min_genes=3 + ) + + # Plain mean → constant-shift propagates exactly. (A z-scored variant + # would centre to ~0 for both, breaking this property.) 
+ np.testing.assert_array_almost_equal(acts_b[:, 0] - acts_a[:, 0], 5.0) + + def test_load_rejects_missing_version(self, tmp_path): + """A file with no format_version attribute must be rejected.""" + import h5py + from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + load_pathway_activities, + ) + + n_spots, n_pathways = 5, 2 + h5_path = str(tmp_path / "no_version.h5") + with h5py.File(h5_path, "w") as f: + f.create_dataset( + "activities", + data=np.zeros((n_spots, n_pathways), dtype=np.float32), + ) + f.create_dataset( + "barcodes", + data=np.array([f"S{i}" for i in range(n_spots)], dtype="S"), + ) + f.create_dataset( + "pathway_names", + data=np.array([f"P{i}" for i in range(n_pathways)], dtype="S"), + ) + # Intentionally omit f.attrs["format_version"] + + with pytest.raises(ValueError, match="format_version"): + load_pathway_activities(h5_path, [f"S{i}" for i in range(n_spots)]) + + def test_load_rejects_old_version(self, tmp_path): + """A file with format_version=1 (old z-scored layout) must be rejected.""" + import h5py + from spatial_transcript_former.recipes.hest.compute_pathway_activities import ( + load_pathway_activities, + ) + + n_spots, n_pathways = 5, 2 + h5_path = str(tmp_path / "v1.h5") + with h5py.File(h5_path, "w") as f: + f.create_dataset( + "activities", + data=np.zeros((n_spots, n_pathways), dtype=np.float32), + ) + f.create_dataset( + "barcodes", + data=np.array([f"S{i}" for i in range(n_spots)], dtype="S"), + ) + f.create_dataset( + "pathway_names", + data=np.array([f"P{i}" for i in range(n_pathways)], dtype="S"), + ) + f.attrs["format_version"] = 1 + + with pytest.raises(ValueError, match="format_version"): + load_pathway_activities(h5_path, [f"S{i}" for i in range(n_spots)]) diff --git a/tests/data/test_visualization.py b/tests/data/test_visualization.py index ef46ce8..f93a532 100644 --- a/tests/data/test_visualization.py +++ b/tests/data/test_visualization.py @@ -5,6 +5,7 @@ import os import tempfile +import 
h5py import pytest import numpy as np import matplotlib @@ -73,6 +74,101 @@ def mock_data(pathway_names): return coords, pathway_pred, pathway_truth +# --------------------------------------------------------------------------- +# _get_pathway_names +# --------------------------------------------------------------------------- + + +class TestGetPathwayNames: + """Tests for the three-tier pathway name resolution in visualization.py.""" + + def _make_args(self, pathways=None, pathway_targets_dir=None): + """Minimal args namespace.""" + import argparse + + args = argparse.Namespace() + args.pathways = pathways + args.pathway_targets_dir = pathway_targets_dir + return args + + def test_explicit_pathways_used_directly(self): + """When args.pathways matches num_expected, return it immediately.""" + from spatial_transcript_former.visualization import _get_pathway_names + + names = ["HALLMARK_WNT_BETA_CATENIN_SIGNALING", "HALLMARK_KRAS_SIGNALING_UP"] + args = self._make_args(pathways=names) + result = _get_pathway_names(args, num_expected=2) + assert result == names + + def test_explicit_pathways_length_mismatch_falls_through(self, tmp_path): + """If args.pathways length != num_expected, fall through to h5 lookup.""" + from spatial_transcript_former.visualization import _get_pathway_names + + # Write an h5 with 3 pathway names + h5_file = tmp_path / "sample.h5" + pw_names = [b"HALLMARK_A", b"HALLMARK_B", b"HALLMARK_C"] + with h5py.File(h5_file, "w") as f: + f.create_dataset("pathway_names", data=pw_names) + + # args.pathways has 2 names but num_expected is 3 — mismatch + args = self._make_args( + pathways=["HALLMARK_X", "HALLMARK_Y"], + pathway_targets_dir=str(tmp_path), + ) + result = _get_pathway_names(args, num_expected=3) + assert result == ["HALLMARK_A", "HALLMARK_B", "HALLMARK_C"] + + def test_h5_names_loaded_from_targets_dir(self, tmp_path): + """When args.pathways is None, names are read from the .h5 file.""" + from spatial_transcript_former.visualization import 
_get_pathway_names + + pw_names = [b"HALLMARK_HYPOXIA", b"HALLMARK_APOPTOSIS", b"HALLMARK_P53_PATHWAY"] + h5_file = tmp_path / "TENX001.h5" + with h5py.File(h5_file, "w") as f: + f.create_dataset("pathway_names", data=pw_names) + + args = self._make_args(pathway_targets_dir=str(tmp_path)) + result = _get_pathway_names(args, num_expected=3) + assert result == [ + "HALLMARK_HYPOXIA", + "HALLMARK_APOPTOSIS", + "HALLMARK_P53_PATHWAY", + ] + + def test_h5_names_decoded_from_bytes(self, tmp_path): + """Byte-string pathway names in .h5 should be decoded to str.""" + from spatial_transcript_former.visualization import _get_pathway_names + + h5_file = tmp_path / "sample.h5" + with h5py.File(h5_file, "w") as f: + f.create_dataset("pathway_names", data=[b"PATHWAY_A", b"PATHWAY_B"]) + + args = self._make_args(pathway_targets_dir=str(tmp_path)) + result = _get_pathway_names(args, num_expected=2) + assert all(isinstance(n, str) for n in result) + + def test_fallback_to_generic_names(self): + """When no pathways and no h5 dir, return generic Pathway_i labels.""" + from spatial_transcript_former.visualization import _get_pathway_names + + args = self._make_args() + result = _get_pathway_names(args, num_expected=4) + assert result == ["Pathway_0", "Pathway_1", "Pathway_2", "Pathway_3"] + + def test_h5_length_mismatch_falls_to_generic(self, tmp_path): + """If h5 pathway count != num_expected, fall through to generic names.""" + from spatial_transcript_former.visualization import _get_pathway_names + + h5_file = tmp_path / "sample.h5" + with h5py.File(h5_file, "w") as f: + f.create_dataset("pathway_names", data=[b"HALLMARK_A", b"HALLMARK_B"]) + + # h5 has 2 but we expect 5 + args = self._make_args(pathway_targets_dir=str(tmp_path)) + result = _get_pathway_names(args, num_expected=5) + assert result == [f"Pathway_{i}" for i in range(5)] + + # --------------------------------------------------------------------------- # Pathway constants # 
--------------------------------------------------------------------------- diff --git a/tests/models/test_interactions.py b/tests/models/test_interactions.py index b0e58e7..9ee83fb 100644 --- a/tests/models/test_interactions.py +++ b/tests/models/test_interactions.py @@ -297,7 +297,7 @@ def test_engine_passes_coords_to_forward(): fake_coords = torch.randn(2, 5, 2) fake_mask = torch.zeros(2, 5).bool() - # Dataloader yielding (feats, genes, pathway_targets, coords, mask, pathway_morans) + # Dataloader yielding (feats, genes, pathway_targets, coords, mask) loader = [ ( torch.randn(2, 5, 512), @@ -305,7 +305,6 @@ def test_engine_passes_coords_to_forward(): torch.randn(2, 5, 50), fake_coords, fake_mask, - None, ) ] @@ -340,7 +339,6 @@ def test_engine_validate_passes_coords(): torch.randn(2, 5, 50), fake_coords, torch.zeros(2, 5).bool(), - None, ) ] @@ -360,21 +358,29 @@ def dummy_criterion(p, t, mask=None): ), "Validate engine passed wrong coordinate tensor!" -def test_spatial_encoder_normalization(): - """Verify LearnedSpatialEncoder handles extreme coords and centers them.""" - encoder = LearnedSpatialEncoder(64) - # Extreme coordinates: very far and very close - coords = torch.tensor([[[1000.0, 1000.0], [1000.1, 1000.1]]]) - normed = encoder._normalize_coords(coords) +def test_spatial_encoder_per_spot_independence(): + """LearnedSpatialEncoder output for spot i must depend only on coords[i]. - # Should be centered (mean 0) - assert torch.allclose(normed.mean(dim=1), torch.zeros(1, 2), atol=1e-5) - # Should be bounded by [-1, 1] - assert normed.abs().max() <= 1.0 + Slide-stationary normalisation is now done in the dataset, so the encoder + itself is per-spot stateless. Permuting other spots in the batch must not + change a given spot's embedding. 
+ """ + encoder = LearnedSpatialEncoder(64).eval() - # Verify forward doesn't crash - out = encoder(coords) - assert out.shape == (1, 2, 64) + coords_a = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]]) + coords_b = torch.tensor( + [[[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]] + ) # same set, reordered + + with torch.no_grad(): + out_a = encoder(coords_a) + out_b = encoder(coords_b) + + # Spot 0 ((0,0)) is in the same position in both — embedding must match + assert torch.allclose(out_a[0, 0], out_b[0, 0], atol=1e-6) + # Spot (1,0) at index 1 in batch a, index 2 in batch b — same embedding + assert torch.allclose(out_a[0, 1], out_b[0, 2], atol=1e-6) + assert out_a.shape == (1, 3, 64) def test_interaction_mask_bits(): diff --git a/tests/recipes/hest/test_dataset.py b/tests/recipes/hest/test_dataset.py index 3a63cee..1097fc5 100644 --- a/tests/recipes/hest/test_dataset.py +++ b/tests/recipes/hest/test_dataset.py @@ -211,8 +211,8 @@ def test_hest_feature_dataset_neighborhood_dropout(): # Run multiple times to trigger the stochastic dropout dropout_occurred = False for _ in range(100): - # Batch: (feats, genes, pathways, coords, morans) - f, g, _, _, _ = ds[0] + # Batch: (feats, genes, pathways, coords) + f, g, _, _ = ds[0] assert g is None, "Genes should be None in pathway-only mode" # Center (index 0) should NEVER be zero assert not torch.all(f[0] == 0) diff --git a/tests/training/test_checkpoints.py b/tests/training/test_checkpoints.py index d8a08e8..4fbcb91 100644 --- a/tests/training/test_checkpoints.py +++ b/tests/training/test_checkpoints.py @@ -59,7 +59,7 @@ def test_save_load_preserves_weights(self, small_model, checkpoint_dir): None, None, # schedulers epoch=42, - best_val_loss=0.123, + best_val_metric=0.123, output_dir=checkpoint_dir, model_name="interaction", ) @@ -105,7 +105,7 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): scaler, None, # schedulers epoch=10, - best_val_loss=0.5, + best_val_metric=0.5, output_dir=checkpoint_dir, 
model_name="interaction", ) @@ -133,13 +133,17 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): assert scaler.state_dict() == fresh_scaler.state_dict() def test_no_checkpoint_starts_fresh(self, small_model, checkpoint_dir): - """Missing checkpoint should return epoch 0 and inf loss.""" + """Missing checkpoint should return epoch 0 and a sentinel best metric. + + The metric is now CCC (higher is better), so the sentinel for "no + checkpoint yet" is ``-inf`` rather than ``+inf``. + """ optimizer = optim.Adam(small_model.parameters(), lr=1e-4) start_epoch, best_val, loaded_schedulers = load_checkpoint( small_model, optimizer, None, None, checkpoint_dir, "nonexistent", "cpu" ) assert start_epoch == 0 - assert best_val == float("inf") + assert best_val == -float("inf") # --------------------------------------------------------------------------- diff --git a/tests/training/test_losses.py b/tests/training/test_losses.py index 7c66f63..693fa56 100644 --- a/tests/training/test_losses.py +++ b/tests/training/test_losses.py @@ -1,5 +1,6 @@ """ -Tests for loss functions: MaskedMSELoss, PCCLoss, and CompositeLoss. +Tests for loss functions: MaskedMSELoss, PCCLoss, CCCLoss, MaskedHuberLoss, +CLIPAlignmentLoss, and CompositeLoss. 
""" import pytest @@ -7,9 +8,12 @@ import torch.nn as nn from spatial_transcript_former.training.losses import ( + CCCLoss, + CLIPAlignmentLoss, + CompositeLoss, + MaskedHuberLoss, MaskedMSELoss, PCCLoss, - CompositeLoss, ) # --------------------------------------------------------------------------- @@ -233,88 +237,199 @@ def test_pcc_edge_cases(): # --------------------------------------------------------------------------- -# Pathway Weights (Moran's I weighting of MSE) +# MaskedHuberLoss # --------------------------------------------------------------------------- -class TestPathwayWeights: - """Tests for per-pathway Moran's I weighting in losses.""" +class TestMaskedHuberLoss: + def test_near_zero_matches_mse(self): + """For small residuals (|x| << delta), Huber ≈ 0.5 * MSE.""" + torch.manual_seed(0) + preds = torch.zeros(16, 50) + target = preds + 0.01 # tiny residuals, well within quadratic zone + huber = MaskedHuberLoss(delta=1.0)(preds, target) + mse = MaskedMSELoss()(preds, target) + # Huber = 0.5 * x^2, MSE = x^2 -> Huber ≈ 0.5 * MSE for small x + assert huber.item() == pytest.approx(0.5 * mse.item(), rel=1e-3) + + def test_large_residuals_sub_quadratic(self): + """For large residuals (|x| >> delta), Huber grows linearly, MSE quadratically.""" + preds = torch.zeros(8, 10) + target = preds + 100.0 # far beyond delta=1.0 + huber = MaskedHuberLoss(delta=1.0)(preds, target) + mse = MaskedMSELoss()(preds, target) + # Huber should be much smaller than MSE for large errors + assert huber.item() < mse.item() + + def test_mask_reduces_loss(self, tensors_3d): + """Masking padded positions should change the loss value.""" + preds, target, mask = tensors_3d + loss_no_mask = MaskedHuberLoss()(preds, target) + loss_masked = MaskedHuberLoss()(preds, target, mask=mask) + assert not torch.allclose(loss_no_mask, loss_masked) - def test_uniform_weights_match_unweighted(self, tensors_2d): - """Uniform weights should produce same loss as no weights.""" - preds, target = 
tensors_2d - G = preds.shape[1] - uniform = torch.ones(G) + def test_gradient_flow(self, tensors_3d): + """Gradients should flow through MaskedHuberLoss.""" + preds, target, mask = tensors_3d + preds = preds.clone().requires_grad_(True) + loss = MaskedHuberLoss()(preds, target, mask=mask) + loss.backward() + assert preds.grad is not None + assert preds.grad.shape == preds.shape - loss_no_w = MaskedMSELoss()(preds, target) - loss_uniform = MaskedMSELoss()(preds, target, pathway_weights=uniform) - assert torch.allclose(loss_no_w, loss_uniform, atol=1e-5) + def test_perfect_predictions_zero(self): + """Loss should be zero for identical preds and target.""" + x = torch.randn(8, 20) + loss = MaskedHuberLoss()(x, x) + assert loss.item() == pytest.approx(0.0, abs=1e-6) - def test_nonuniform_weights_change_loss(self, tensors_2d): - """Non-uniform weights should produce a different loss.""" - preds, target = tensors_2d - G = preds.shape[1] - weights = torch.rand(G) + 0.1 # avoid zeros - loss_no_w = MaskedMSELoss()(preds, target) - loss_w = MaskedMSELoss()(preds, target, pathway_weights=weights) - # Very unlikely to be identical with random weights - assert not torch.allclose(loss_no_w, loss_w, atol=1e-5) +# --------------------------------------------------------------------------- +# CCCLoss +# --------------------------------------------------------------------------- - def test_zero_weight_pathway_contributes_nothing(self): - """A pathway with weight 0 should contribute 0 to the loss.""" - B, G = 8, 4 - torch.manual_seed(42) - preds = torch.randn(B, G) - target = torch.randn(B, G) - # Weight pathway 0 at 0, rest at 1 - weights = torch.tensor([0.0, 1.0, 1.0, 1.0]) - loss_w = MaskedMSELoss()(preds, target, pathway_weights=weights) +class TestCCCLoss: + def test_perfect_predictions_zero(self): + """Perfect predictions should give CCC=1, loss=0.""" + x = torch.randn(50, 100) + loss = CCCLoss()(x, x) + assert loss.item() == pytest.approx(0.0, abs=1e-5) + + def 
test_offset_penalised_more_than_pcc(self): + """A constant offset prediction should have CCC loss > PCC loss.""" + torch.manual_seed(42) + x = torch.randn(50, 100) + # Shift predictions by a large constant — PCC is shift-invariant, CCC is not + y = x + 5.0 + pcc_loss = PCCLoss()(x, y) + ccc_loss = CCCLoss()(x, y) + # PCC should be ~0 (perfect correlation), CCC should be larger + assert pcc_loss.item() == pytest.approx(0.0, abs=1e-4) + assert ccc_loss.item() > pcc_loss.item() + 0.1 - # Reference: MSE on only pathways 1-3 - diff_sq = (preds[:, 1:] - target[:, 1:]) ** 2 - expected = diff_sq.mean() - assert torch.allclose(loss_w, expected, atol=1e-5) + def test_3d_with_mask(self, tensors_3d): + """CCCLoss should handle 3D inputs with mask.""" + preds, target, mask = tensors_3d + loss = CCCLoss()(preds, target, mask=mask) + assert loss.isfinite() + assert 0.0 <= loss.item() <= 2.0 - def test_gradient_flow_with_weights(self, tensors_2d): - """Gradients should flow correctly with pathway weights.""" + def test_gradient_flow(self, tensors_2d): + """Gradients should flow through CCCLoss.""" preds, target = tensors_2d preds = preds.clone().requires_grad_(True) - G = preds.shape[1] - weights = torch.rand(G) + 0.1 - - loss = MaskedMSELoss()(preds, target, pathway_weights=weights) + loss = CCCLoss()(preds, target) loss.backward() assert preds.grad is not None - assert preds.grad.shape == preds.shape - def test_3d_with_mask_and_weights(self, tensors_3d): - """Pathway weights should work with 3D inputs and masking.""" - preds, target, mask = tensors_3d - G = preds.shape[2] - weights = torch.rand(G) + 0.1 + def test_anticorrelation(self): + """Negated inputs should give loss > 1 (CCC is strongly negative).""" + x = torch.randn(50, 100) + loss = CCCLoss()(x, -x) + # CCC of x and -x is negative, so 1 - CCC > 1 + assert loss.item() > 1.0 + + +# --------------------------------------------------------------------------- +# CLIPAlignmentLoss +# 
--------------------------------------------------------------------------- - loss = MaskedMSELoss()(preds, target, mask=mask, pathway_weights=weights) + +class TestCLIPAlignmentLoss: + def test_batch_size_one_returns_zero(self): + """B=1 should return 0.0 without crashing.""" + preds = torch.randn(1, 50) + target = torch.randn(1, 50) + loss = CLIPAlignmentLoss()(preds, target) + assert loss.item() == pytest.approx(0.0, abs=1e-6) + assert loss.requires_grad + + def test_identical_batch_has_high_loss(self): + """If all predictions are identical, cross-entropy should be near log(B).""" + B, G = 16, 50 + # All samples predict the same vector — worst case for CLIP + preds = torch.ones(B, G) + target = torch.randn(B, G) + loss = CLIPAlignmentLoss(temperature=1.0)(preds, target) + # Uniform distribution over B classes → cross-entropy ≈ log(B) + import math + + assert loss.item() == pytest.approx(math.log(B), rel=0.05) + + def test_perfect_batch_has_low_loss(self): + """Perfect predictions (preds == target) should give near-zero loss.""" + torch.manual_seed(7) + x = torch.randn(8, 50) + loss = CLIPAlignmentLoss(temperature=0.07)(x, x) + # With identical embeddings the diagonal always wins → loss ≈ 0 + assert loss.item() < 0.05 + + def test_3d_input_averaged(self, tensors_3d): + """3D inputs should be averaged over spatial dim before CLIP loss.""" + preds, target, mask = tensors_3d + loss = CLIPAlignmentLoss()(preds, target, mask=mask) assert loss.isfinite() - def test_composite_weights_affect_mse_not_pcc(self, tensors_2d): - """CompositeLoss should pass weights to MSE only, not PCC.""" + def test_gradient_flow(self, tensors_2d): + """Gradients should flow through CLIPAlignmentLoss.""" preds, target = tensors_2d - G = preds.shape[1] - weights = torch.rand(G) + 0.1 + preds = preds.clone().requires_grad_(True) + loss = CLIPAlignmentLoss()(preds, target) + loss.backward() + assert preds.grad is not None - # Compute parts manually - mse_weighted = MaskedMSELoss()(preds, 
target, pathway_weights=weights) - pcc_unweighted = PCCLoss()(preds, target) - expected = mse_weighted + 1.0 * pcc_unweighted - actual = CompositeLoss(alpha=1.0)(preds, target, pathway_weights=weights) - assert torch.allclose(expected, actual, atol=1e-5) +# --------------------------------------------------------------------------- +# CompositeLoss — new variants +# --------------------------------------------------------------------------- + + +class TestCompositeLossVariants: + def test_mse_ccc_gradients_flow(self, tensors_3d): + """CompositeLoss(pcc_type='ccc') gradients flow through both terms.""" + preds, target, mask = tensors_3d + preds = preds.clone().requires_grad_(True) + loss = CompositeLoss(alpha=1.0, pcc_type="ccc")(preds, target, mask=mask) + loss.backward() + assert preds.grad is not None + padded_grad = preds.grad[0, 80:, :] + assert padded_grad.abs().sum() == 0.0 + + def test_mse_ccc_clip_all_three_terms(self, tensors_2d): + """CompositeLoss(pcc_type='ccc', clip_weight=0.5) should be > mse+ccc alone.""" + preds, target = tensors_2d + torch.manual_seed(0) + loss_no_clip = CompositeLoss(alpha=1.0, pcc_type="ccc", clip_weight=0.0)( + preds, target + ) + loss_with_clip = CompositeLoss(alpha=1.0, pcc_type="ccc", clip_weight=0.5)( + preds, target + ) + # CLIP term adds a positive value so combined loss should differ + assert not torch.allclose(loss_no_clip, loss_with_clip) + + def test_mse_ccc_clip_gradients_flow(self, tensors_2d): + """All three terms in mse+ccc+clip should contribute gradients.""" + preds, target = tensors_2d + preds = preds.clone().requires_grad_(True) + loss = CompositeLoss(alpha=1.0, pcc_type="ccc", clip_weight=0.5)(preds, target) + loss.backward() + assert preds.grad is not None - def test_composite_no_weights_unchanged(self, tensors_2d): - """CompositeLoss without weights should behave identically to before.""" + def test_mse_huber_gradients_flow(self, tensors_2d): + """CompositeLoss(mse_type='huber') should produce valid 
gradients.""" preds, target = tensors_2d - loss_none = CompositeLoss(alpha=1.0)(preds, target, pathway_weights=None) - loss_orig = CompositeLoss(alpha=1.0)(preds, target) - assert torch.allclose(loss_none, loss_orig, atol=1e-6) + preds = preds.clone().requires_grad_(True) + loss = CompositeLoss(alpha=1.0, mse_type="huber")(preds, target) + loss.backward() + assert preds.grad is not None + + def test_regression_mse_pcc_unchanged(self, tensors_2d): + """Default CompositeLoss (pcc_type='pcc') should still equal MSE + PCC.""" + preds, target = tensors_2d + mse_val = MaskedMSELoss()(preds, target) + pcc_val = PCCLoss()(preds, target) + expected = mse_val + 1.0 * pcc_val + actual = CompositeLoss(alpha=1.0)(preds, target) + assert torch.allclose(expected, actual, atol=1e-5)