diff --git a/.gitignore b/.gitignore index b1744e0..ed95289 100644 --- a/.gitignore +++ b/.gitignore @@ -262,18 +262,7 @@ hest_data/patches/TENX175.h5 hest_data/st/TENX175.h5ad hest_data/tissue_seg/TENX175_contours.geojson hest_data/wsis/TENX175.tif -.idea/.gitignore -.idea/csv-editor.xml -.idea/deployment.xml -.idea/jupyter-settings.xml -.idea/misc.xml -.idea/modules.xml -.idea/SpatialTranscriptFormer.iml -.idea/vcs.xml -.idea/inspectionProfiles/profiles_settings.xml -.idea/inspectionProfiles/Project_Default.xml -.idea/runConfigurations/STF_Compute_Pathways.xml -.idea/runConfigurations/STF_Train_PrimaryPathway.xml +.idea/ .gemini/settings.json .gemini/agents/literature-search.md .gemini/agents/test-triage.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..668bcce --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,52 @@ +# Changelog + +All notable changes to the SpatialTranscriptFormer project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +--- + +## [Unreleased] + +### Added +- Created `CHANGELOG.md` documenting project history, milestones, and design choices. +- Documented the role of Moran's I (diagnostic target validation and spatial representation collapse detection) in [PATHWAY_MAPPING.md](docs/PATHWAY_MAPPING.md) and [spatial_stats.py](src/spatial_transcript_former/data/spatial_stats.py). + +### Changed +- Refactored baseline models (`HE2RNA`, `ViT_ST` in [regression.py](src/spatial_transcript_former/models/regression.py)) to accept `num_pathways` instead of `num_genes` and directly regress pathway activities. +- Corrected console script entry points in [pyproject.toml](pyproject.toml) to map to `recipes/hest/` instead of `data/`. +- Updated [setup.ps1](setup.ps1) and [setup.sh](setup.sh) to suggest `stf-compute-pathways` instead of `stf-build-vocab`. 
+- Cleaned up parameter descriptions and docstrings in [dataset.py](src/spatial_transcript_former/recipes/hest/dataset.py), [trainer.py](src/spatial_transcript_former/training/trainer.py), and [checkpoint.py](src/spatial_transcript_former/checkpoint.py). +- Completely updated documentation files ([DATALOADER.md](docs/DATALOADER.md), [MODELS.md](docs/MODELS.md), [SC_BEST_PRACTICES.md](docs/SC_BEST_PRACTICES.md), [TRAINING_GUIDE.md](docs/TRAINING_GUIDE.md), [TESTING.md](docs/TESTING.md), [PRECOMPUTED_WORKFLOW.md](docs/PRECOMPUTED_WORKFLOW.md), [DATA_FORMAT.md](docs/DATA_FORMAT.md)) to reflect the pathway-exclusive paradigm and remove legacy gene-reconstruction references. + +### Removed +- Deleted obsolete gene vocabulary builder script `build_vocab.py`. +- Deleted obsolete gene availability analysis document [GENE_ANALYSIS.md](docs/GENE_ANALYSIS.md). + +--- + +## [0.2.0] - 2026-06 + +### Added +- Integrated multi-loss framework containing Concordance Correlation Coefficient (CCC), Huber loss, and CLIP-style contrastive loss to improve target convergence and model robustness. +- Added direct supervision head for pre-computed pathway targets, eliminating circular dependency issues from older auxiliary pathway loss architectures. +- Created public inference API and model wrapping framework. +- Introduced Moran's I diagnostics for Spatially Variable Gene (SVG) selection and spatial pattern evaluation. +- Added licensing disclaimers and specific attribution details for MSigDB Hallmark gene sets (CC BY 4.0), HEST-1k dataset, and third-party foundation models (CTransPath, Phikon). + +### Fixed +- Resolved `TypeError` in transformer encoder by placing `enable_nested_tensor=False` in PyTorch's `TransformerEncoder` constructor. +- Configured pytest warnings filter in `pyproject.toml` to suppress non-critical output noise (e.g. deprecations from third-party libraries). 
+ +--- + +## [0.1.0] - 2026-03 + +### Added +- Initialized core package architecture, modules, test suite, and scripts. +- Implemented the quad-flow interaction system (early fusion of spatial transcriptomics and whole-slide histology features). +- Added `LocalPatchMixer` module (Scatter-Gather depthwise 2D convolutions) to introduce localized spatial inductive biases into slide spot processing. +- Added support for pre-computing histology feature extraction (e.g. using CTransPath) and building KD-Tree representations for spatial neighbor retrieval. +- Developed an interactive Matplotlib visualization widget to overlay predicted pathway activities on histology slide coordinates. +- Set up GitHub Actions CI workflow for automated testing. diff --git a/docs/DATALOADER.md b/docs/DATALOADER.md index 3ac3943..fee010b 100644 --- a/docs/DATALOADER.md +++ b/docs/DATALOADER.md @@ -1,59 +1,64 @@ # HEST Dataloader Documentation -The `SpatialTranscriptFormer` uses a custom PyTorch dataloader designed for memory-efficient loading of large-scale spatial transcriptomics datasets. +The `SpatialTranscriptFormer` uses custom PyTorch dataloaders designed for memory-efficient loading of large-scale spatial transcriptomics datasets. The framework supports two loading paths: loading raw histology patches or loading pre-extracted feature vectors. ## Core Implementation Details -The implementation is located in `src/spatial_transcript_former/data/dataset.py`. +The implementation is located in [dataset.py](../src/spatial_transcript_former/recipes/hest/dataset.py). -### 1. `HEST_Dataset` Class +### 1. Raw-Patch Loading Path -This class implements the standard `torch.utils.data.Dataset` interface. +This path is used when training or evaluating directly on pixel-space images. -- **Lazy Loading**: To avoid overwhelming memory, it uses lazy loading for H5 file handles. File objects are initialized only when the first item is requested (typically within a worker process). 
-- **Indexing**: It supports an optional `indices` map, which allows it to represent a subset of the original data (e.g., after filtering for valid ST spots) without duplicating arrays in memory. -- **Transformation**: Images are permuted from `(H, W, C)` to `(C, H, W)` and normalized to `[0, 1]`. +* **`HEST_Dataset` Class**: Loads raw histology patches from a HEST `.h5` file. It supports: + * **Lazy File Access**: File handles are created lazily inside each worker process to avoid pickling issues during multiprocessing. + * **Neighbourhood Context**: Can retrieve a patch along with its $K$ nearest neighbours. + * **Dihedral Augmentation**: Randomly rotates or flips patch pixels and coordinates in sync. +* **`get_hest_dataloader`**: High-level orchestrator that creates a `DataLoader` over raw patches for a list of sample IDs, combining individual datasets using `ConcatDataset`. +* **Returned Tuples**: Yields `(patches, None, rel_coords)` where the second element (formerly gene expression counts) is `None`. -### 2. `load_gene_expression_matrix` +### 2. Pre-Computed Feature Loading Path -This utility function handles the complex process of aligning image patches to gene expression data. +This is the default path used by the SpatialTranscriptFormer training pipeline (`--precomputed`), as it avoids repeated backbone inference. -- **Barcode Alignment**: Since not every image patch in an `.h5` file necessarily has a corresponding transcriptomic profile in the `.h5ad` file, the function performs a lookup using the spot barcodes. -- **Gene Selection**: It can either: - 1. Select the top `N` most expressed genes from a single sample. - 2. Align the current sample to a predefined list of global gene names (filling missing genes with zeros). -- **Sparse Support**: It handles both dense and sparse (CSR) matrix formats in the `.h5ad` file. +* **`HEST_FeatureDataset` Class**: Loads pre-extracted feature vectors (e.g. 
CTransPath, Phikon) from `.pt` files and aligns them to pre-computed pathway activity targets from `.h5` files. + * **Spot barcode alignment**: Filters features to keep only spots that passed quality control (QC) in the corresponding `.h5ad` file. + * **Stationary Coordinate Normalisation**: Normalises coordinates relative to the slide's centroid and standard deviation so coordinates are invariant to batching. + * **Patch Mode**: Returns a single spot feature vector, its local neighbourhood features (optionally with random dropout augmentation), pre-computed pathway targets, and relative coordinates. + * **Whole-Slide Mode**: Returns all spots on the slide as a single sequence. +* **`get_hest_feature_dataloader`**: Builds a `DataLoader` over the feature datasets. + * In **patch mode**, yields standard batched tensors `(feats, None, pathway_acts, coords)`. + * In **whole-slide mode**, pads variable-length slides to the longest slide in the batch and appends a boolean padding mask. Yields `(padded_feats, None, padded_pathways, padded_coords, mask)`. -### 3. `get_hest_dataloader` +--- -The high-level orchestrator that creates a unified dataloader for multiple samples. - -- **Sample Concatenation**: It iterates through multiple sample IDs and creates individual `HEST_Dataset` instances, which are then combined using `torch.utils.data.ConcatDataset`. -- **Global Gene Lock**: The first sample found "locks" the gene list (usually the top 1000 genes). Every subsequent sample in the loop is then aligned to this specific set of genes to ensure consistent input dimensions for the model. - -## Usage Example +## Usage Example (Pre-Computed Features) ```python -from spatial_transcript_former.data import get_hest_dataloader +from spatial_transcript_former.recipes.hest.dataset import get_hest_feature_dataloader -# IDs from your metadata split +# Pre-selected training sample IDs train_ids = ['MEND29', 'TENX156', ...] 
-dataloader = get_hest_dataloader( - root_dir="A:/hest_data", +dataloader = get_hest_feature_dataloader( + root_dir="./hest_data", ids=train_ids, batch_size=32, shuffle=True, num_workers=4, - num_genes=1000 + n_neighbors=6, + pathway_targets_dir="./hest_data/pathway_activities" ) -for patches, gene_counts in dataloader: - # patches shape: (BS, 3, 224, 224) - # gene_counts shape: (BS, 1000) +for feats, _, pathway_acts, rel_coords in dataloader: + # feats shape: (BS, 1 + n_neighbors, feature_dim) + # pathway_acts shape: (BS, num_pathways) + # rel_coords shape: (BS, 1 + n_neighbors, 2) ... ``` -## Stratified Splitting +--- + +## Patient-Aware Stratified Splitting -For robust evaluation, we use `split_hest_patients` in `src/spatial_transcript_former/data/splitting.py`. This ensures that all samples from a single patient go into the same split (Train/Val/Test), preventing data leakage due to biological similarities between slides from the same donor. +To prevent data leakage due to biological similarities between multiple slides from the same donor, splits are stratified by patient. The splitting logic is located in [splitting.py](../src/spatial_transcript_former/recipes/hest/splitting.py) and exposed via the `stf-split` command. diff --git a/docs/DATA_FORMAT.md b/docs/DATA_FORMAT.md index f899d06..3aebc99 100644 --- a/docs/DATA_FORMAT.md +++ b/docs/DATA_FORMAT.md @@ -22,7 +22,10 @@ your_data_dir/ │ ├── sample1.pt │ ├── sample2.pt │ └── ... -└── global_genes.json # Generated by stf-build-vocab +└── pathway_activities/ # Pre-computed pathway targets + ├── sample1.h5 + ├── sample2.h5 + └── ... 
``` *Note: If you are training from scratch on raw image crops, you would have a `patches/` directory with `.h5` files instead of the `he_features_ctranspath/` directory.* @@ -41,7 +44,7 @@ If you processed your data using standard tools like **10x SpaceRanger** and loa #### `var` (Variables / Genes) - Must contain an index representing the gene names (e.g., standard HGNC symbols like `TRAP1`, `BRCA1`). -- These names are used by `stf-build-vocab` to map your dataset to biological pathways (like MSigDB Hallmarks). +- These names are used by `stf-compute-pathways` to map your dataset to biological pathways (like MSigDB Hallmarks) during target pre-computation. #### `X` (Expression Matrix) @@ -75,7 +78,7 @@ adata.uns['spatial'] = { Once your files match the structure above: 1. **Place data:** Put your `.h5ad` files into `your_data_dir/st/`. -2. **Build Vocabulary:** Run `stf-build-vocab --data-dir your_data_dir/`. This will scan all your `.h5ad` files, find the most highly expressed genes, map them to biological pathways, and generate `global_genes.json`. +2. **Compute Pathway Targets:** Run `stf-compute-pathways --data-dir your_data_dir/`. This will process your `.h5ad` files, apply spot quality control, and pre-compute the Hallmark pathway activity target matrices saved to `your_data_dir/pathway_activities/`. 3. **Extract Features (Optional):** If you haven't already, run the feature extraction pipeline (e.g., `stf-extract`) to generate the `.pt` files in `he_features_ctranspath/`. 4. **Train:** You are now ready to run `stf-train`! diff --git a/docs/GENE_ANALYSIS.md b/docs/GENE_ANALYSIS.md deleted file mode 100644 index dd4ec1f..0000000 --- a/docs/GENE_ANALYSIS.md +++ /dev/null @@ -1,52 +0,0 @@ -# Gene Analysis and Modeling Strategies - -This document outlines the available gene sets for modeling in the Bowel Cancer dataset, based on an analysis of 84 Human Bowel samples across different technologies (Visium, Visium HD, and Xenium). 
- -## Gene Availability by Platform - -Biological analysis is constrained by the intersection of genes available across different spatial transcriptomics platforms. - -| Scope | Sample Count | Common Genes | Recommendation | -| :--- | :--- | :--- | :--- | -| **All Human Bowel** | 84 (Visium + Xenium) | ~405 | Cross-platform benchmarking | -| **Visium Only** | 78 | ~1060 | In-depth spatial profiling | - -### 1. The "Bowel Core" Gene Set (405 genes) - -This set represents the intersection of the Xenium panel and the Visium whole-transcriptome capture. - -- **Pros**: Allows the model to be trained on Visium data and evaluated on high-resolution Xenium data. -- **Cons**: Limited to a smaller subset of genes, which might miss important specific pathways. - -### 2. The "Visium Pan-Bowel" Gene Set (1060 genes) - -This set includes all genes present in every Visium sample in the HEST dataset. - -- **Pros**: Provides a much larger feature space (predicting 1000+ genes). -- **Cons**: Cannot be directly evaluated on Xenium samples without imputation or subsetting. - -## Implementation in the Dataloader - -The current dataloader implementation in `src/spatial_transcript_former/data/dataset.py` uses a "Gene Lock" mechanism: - -1. The first sample in the training loop determines the target gene list. -2. All subsequent samples are aligned to this list (missing genes are filled with zeros). - -### Recommended Strategy - -To ensure the best model stability, it is recommended to provide an explicit list of gene names to the `get_hest_dataloader` function instead of relying on the first sample's top genes. - -```python -# Create a fixed gene list for the project -bowel_genes = [...] # The 405 or 1060 common genes - -dataloader = get_hest_dataloader( - ids=sample_ids, - selected_gene_names=bowel_genes, - ... -) -``` - -## How to find specific genes? 
- -You can use the `inspection/analyze_gene_overlap.py` script to generate custom gene sets based on your specific sample filtering criteria. diff --git a/docs/MODELS.md b/docs/MODELS.md index b0fda0b..f9baa40 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -56,7 +56,7 @@ The spatial relationships of gene expression are central to this model. It is no 1. **Positional Encoding** — Each patch token receives a 2D sinusoidal encoding of its (x, y) coordinate on the tissue. This means the pathway tokens, when they attend to patches, can distinguish *where* each patch is. A pathway token can learn that EMT is localised at the tumour-stroma boundary, not uniformly across the slide. -2. **PCC Loss (Spatial Pattern Coherence)** — The Pearson Correlation component in the composite loss measures whether the *spatial pattern* of each gene's predicted expression matches the ground truth pattern, independently of scale. A model that predicts the same value everywhere scores PCC = 0, even if the mean is correct. This directly penalises spatial collapse. +2. **PCC Loss (Spatial Pattern Coherence)** — The Pearson Correlation component in the composite loss measures whether the *spatial pattern* of each pathway's predicted activity matches the ground truth pattern, independently of scale. A model that predicts the same value everywhere scores PCC = 0, even if the mean is correct. This directly penalises spatial collapse. Together, these ensure the model learns *spatially-varying* pathway activation maps rather than slide-level averages. 
@@ -157,8 +157,8 @@ where $\hat{h}_i$ and $\hat{p}_k$ are the L2-normalised patch and pathway tokens | Mode | Input | Output | Supervision | | :--- | :--- | :--- | :--- | -| **Dense (whole-slide)** | All patches from a slide | Per-patch gene predictions $(B, S, G)$ | Masked MSE+PCC at each spot | -| **Global** | All patches from a slide | Slide-level prediction $(B, G)$ | Mean-pooled expression | +| **Dense (whole-slide)** | All patches from a slide | Per-patch pathway predictions $(B, S, P)$ | Masked MSE+PCC at each spot | +| **Global** | All patches from a slide | Slide-level pathway prediction $(B, P)$ | Mean-pooled pathway activities | --- diff --git a/docs/PATHWAY_MAPPING.md b/docs/PATHWAY_MAPPING.md index 0c40cc7..11f64e8 100644 --- a/docs/PATHWAY_MAPPING.md +++ b/docs/PATHWAY_MAPPING.md @@ -59,6 +59,14 @@ attrs: These files are consumed at training time by `HEST_FeatureDataset` when `--pathway-targets-dir` is provided (which defaults to `/pathway_activities`). +### The Role of Moran's I in the Project + +Historically, Moran's I was introduced to rank and select Spatially Variable Genes (SVGs) when the model was trained on high-dimensional gene expression reconstruction. With the transition to the strictly pathway-exclusive architecture, its role has shifted: + +- **Why it is NOT used in the Loss**: Down-weighting pathways with low Moran's I during training was dropped because it is counterproductive. Crucial cancer pathways (e.g., Wnt/β-catenin) can exhibit low spatial autocorrelation across spots due to constitutive activation (from driver mutations like APC), yet remain key targets that the model must predict. +- **Why it is kept as a Diagnostic**: The pre-computed `pathway_morans_i` dataset in the `.h5` files acts as a slide-level spatial signature. It is used to curate disease-specific pathway sets (ensuring targets are above a background noise floor of ~0.15) and for offline biological analysis. 
+- **Role in validation (Collapse Detection)**: During validation, the training engine dynamically computes the Pearson correlation of predicted vs. ground-truth Moran's I vectors across the pathways. If the model suffers from spot-level representation collapse (predicting identical mean values everywhere), the predicted Moran's I drops to 0, which immediately registers as a drop in the validation `spatial_coherence` score. + ### Usage ```bash diff --git a/docs/PRECOMPUTED_WORKFLOW.md b/docs/PRECOMPUTED_WORKFLOW.md index ba03607..169607f 100644 --- a/docs/PRECOMPUTED_WORKFLOW.md +++ b/docs/PRECOMPUTED_WORKFLOW.md @@ -7,7 +7,7 @@ This workflow enables training the Spatial TranscriptFormer using pre-computed f Run the extraction script to process H&E patches and save feature tensors to `he_features/`. ```powershell -python src/spatial_transcript_former/data/extract_features.py --data-dir A:\hest_data --backbone resnet50 --batch-size 32 +stf-extract --data-dir hest_data --backbone resnet50 --batch-size 32 ``` **Arguments:** @@ -22,7 +22,7 @@ python src/spatial_transcript_former/data/extract_features.py --data-dir A:\hest Train the model using the `--precomputed` flag. The script will automatically filter for samples that have existing feature files. ```powershell -python src/spatial_transcript_former/train.py --data-dir A:\hest_data --model interaction --precomputed --epochs 50 --batch-size 32 --n-neighbors 6 +stf-train --data-dir hest_data --model interaction --precomputed --epochs 50 --batch-size 32 --n-neighbors 6 ``` **Key Arguments:** @@ -36,7 +36,7 @@ python src/spatial_transcript_former/train.py --data-dir A:\hest_data --model in To model long-range interactions (similar to Jaume et al.), use the `--use-global-context` flag. This will mix randomly sampled patches from the entire slide into the context window for each training sample. 
```powershell -python src/spatial_transcript_former/train.py --data-dir A:\hest_data --model interaction --precomputed --use-global-context --global-context-size 256 +stf-train --data-dir hest_data --model interaction --precomputed --use-global-context --global-context-size 256 ``` - `--use-global-context`: Enables mixing of global patches. @@ -47,7 +47,7 @@ python src/spatial_transcript_former/train.py --data-dir A:\hest_data --model in To use **only global context** (and the center patch), simply set neighbors to 0: ```powershell -python src/spatial_transcript_former/train.py --data-dir A:\hest_data --model interaction --precomputed --use-global-context --global-context-size 256 --n-neighbors 0 +stf-train --data-dir hest_data --model interaction --precomputed --use-global-context --global-context-size 256 --n-neighbors 0 ``` This will construct a sequence of `[Center Patch] + [256 Random Global Patches]`. diff --git a/docs/SC_BEST_PRACTICES.md b/docs/SC_BEST_PRACTICES.md index 193e186..be1dc1f 100644 --- a/docs/SC_BEST_PRACTICES.md +++ b/docs/SC_BEST_PRACTICES.md @@ -9,27 +9,13 @@ Best-practices alignment and literature-driven improvement leads for SpatialTran These areas already follow industry best practices: -- **Global Gene Vocabulary** — `build_vocab.py` enforces a consistent feature space - across all samples, preventing feature mismatch at training and inference time. -- **SVG-aware Gene Selection** — Moran's I scoring is implemented in - `data/spatial_stats.py` and exposed via `stf-build-vocab --svg-weight`. Genes with - high spatial autocorrelation are the strongest learning targets for a spatially-aware - model. -- **Spatial Coherence Validation** — `spatial_coherence_score()` compares predicted - vs. ground-truth Moran's I for the top-50 SVGs, logging a Pearson correlation as a - validation metric alongside MSE/PCC. Computed automatically in `training/engine.py`. 
-- **Spatial Context via Neighbourhoods** — KD-tree-based neighbour aggregation in - `HEST_FeatureDataset` provides local spatial context to patch features. -- **Coordinate Standardisation** — `normalize_coordinates()` prevents spatial scale - bias between slides from different platforms (Visium, Visium HD, etc.). -- **Pathway-Aware Feature Selection** — MSigDB pathway gene prioritisation in the - vocabulary builder ensures biologically relevant signal even with a limited gene - budget. -- **Statistical Loss Modelling** — `ZINBLoss` accounts for the overdispersion and - sparsity inherent in raw count data. -- **Histology-Gene Integration** — The quad-flow architecture follows recommended - multi-modal integration patterns; pathway tokens act as a structured biological - bottleneck analogous to the Perceiver cross-attention design. +- **Global Pathway Targets** — Pre-computed using a standardized offline script (`stf-compute-pathways`) and stored as slide-stationary matrices, eliminating gene-to-pathway conversion latency during training. +- **Spatially Variable Pathway Targets** — Ground-truth targets represent curated MSigDB Hallmark pathway activity scores, which exhibit much higher spatial coherence than noisy individual gene expression values. +- **Spatial Coherence Validation** — Diagnostic Moran's I is computed for each pathway across slide spots to confirm predicted patterns retain biological spatial autocorrelation. +- **Spatial Context via Neighbourhoods** — KD-tree-based neighbour aggregation in `HEST_FeatureDataset` provides local spatial context to patch features. +- **Coordinate Standardisation** — `normalize_coordinates()` prevents spatial scale bias between slides from different platforms (Visium, Visium HD, etc.). +- **Direct Pathway Supervision** — Decoupled pathway target supervision eliminates circular auxiliary losses and noise from high-dimensional gene reconstruction. 
+- **Histology-Pathway Integration** — The quad-flow architecture follows recommended multi-modal integration patterns; pathway tokens act as a structured biological bottleneck analogous to the Perceiver cross-attention design. --- @@ -153,60 +139,30 @@ Status key: ✅ Implemented | 🔧 Open | 💡 Research lead ### Vocabulary & Preprocessing -**1. Mitochondrial gene filtering** 🔧 — High priority, low effort +**1. Mitochondrial gene filtering** ✅ — Completed -`global_genes_stats.csv` shows MT-CO3 and MT-CO2 in the top-20 most expressed genes. -Mitochondrial genes (`MT-*`) reflect cell health and apoptotic state, not spatial -morphological patterns, and are universally filtered in standard ST preprocessing. -Their presence inflates expression-rank scores for non-informative targets. +Mitochondrial genes (`MT-*`) reflect cell health and apoptotic state, not spatial morphological patterns. Standard spot-level QC removes spots whose mitochondrial read fraction is too high. +*Implementation*: The `stf-compute-pathways` pipeline drops spots whose mitochondrial read fraction exceeds a configurable maximum (`--qc-max-mt`, default: 0.15) during target pre-computation. -*Fix*: Exclude genes matching the `MT-` prefix before ranking in `build_vocab.py`. +**2. SVG-weighted vocab rebuild** 🚫 — Superseded -**2. SVG-weighted vocab rebuild** 🔧 — High priority, low effort +Superseded by the transition to a pathway-exclusive prediction paradigm, removing gene-level vocabulary alignment entirely. -The current `global_genes.json` was built with expression-only ranking — the -`global_genes_stats.csv` has no `morans_i` column, confirming `--svg-weight=0.0` was -used. SVG selection is now validated standard practice (SpatialDE, SPARK, NNSVG are -catalogued as de facto tools in the guidebook). +**3. Vectorise `morans_i` calculations** 🚫 — N/A -*Fix*: Re-run `stf-build-vocab --svg-weight 0.5 --svg-k 6` after applying the -MT-gene filter above. 
+Computing spatial autocorrelation over 50 pathway activities is computationally negligible, eliminating the need to vectorise high-dimensional gene-level Moran's I loops. -**3. Vectorise `morans_i_batch`** 🔧 — Medium priority, low effort +**4. Dispersion-based (HVG) gene filtering** 🚫 — Superseded -`spatial_stats.py:morans_i_batch` loops over G genes with individual Python calls. -For 38,839 genes across many samples this is the bottleneck for SVG-weighted runs. -One sparse matrix multiply replaces the loop: +Superseded by the transition to a pathway-exclusive prediction paradigm. -```python -z = expression - expression.mean(axis=0) # (N, G) -lag = W.dot(z) # (N, G) -num = (z * lag).sum(axis=0) # (G,) -den = (z**2).sum(axis=0) # (G,) -scores = np.where(den > 1e-12, (n / W_sum) * num / den, 0.0) -``` - -**4. Dispersion-based (HVG) gene filtering** 🔧 — Medium priority, medium effort - -Beyond total counts and spatial variability, filtering for **Highly Variable Genes** -using dispersion metrics (as in `sc.pp.highly_variable_genes`) focuses the model on -genes that carry biological variation between tissue states rather than static -structural signal. This is complementary to Moran's I: HVG captures across-sample -variation while Moran's I captures within-sample spatial structure. - -**5. Standardised library-size normalisation** 🔧 — Medium-High priority, medium effort +**5. Standardised library-size normalisation** ✅ — Completed -The pipeline currently lacks a standardised CPM/CP10k normalisation step before -`log1p`. Without it, sequencing depth variation between spots biases predictions -toward highly-sequenced spots. Standard: normalise to 10,000 counts per spot, then -`log1p`. +Spots are normalised to 10,000 counts (CP10k) before applying `log1p` during the offline pre-computation step. -**6. Per-spot quality control** 🔧 — Medium priority, medium effort +**6. 
Per-spot quality control** ✅ — Completed -Explicit QC thresholds (minimum UMI count, minimum detected genes, maximum -mitochondrial fraction) in the dataset loading scripts would protect the model from -training on low-quality "noise" spots. The mitochondrial fraction threshold directly -complements the MT-gene vocabulary filter above. +Explicit QC thresholds (minimum UMI count, minimum detected genes, maximum mitochondrial fraction) are enforced at target pre-computation time via `stf-compute-pathways`. --- @@ -335,20 +291,20 @@ would suffice. | # | Direction | Priority | Effort | Status | |---|-----------|----------|--------|--------| -| 1 | MT-gene filter in `build_vocab.py` | High | Low | 🔧 Open | -| 2 | Rebuild vocab with `--svg-weight 0.5` | High | Low | 🔧 Open | -| 3 | Vectorise `morans_i_batch` | Medium | Low | 🔧 Open | -| 4 | HVG dispersion-based gene filtering | Medium | Medium | 🔧 Open | -| 5 | Library-size normalisation (CP10k) | Medium-High | Medium | 🔧 Open | -| 6 | Per-spot QC thresholds | Medium | Medium | 🔧 Open | +| 1 | MT-gene filter in `stf-compute-pathways` | High | Low | ✅ Implemented | +| 2 | Rebuild vocab (Superseded by pathway targets) | High | Low | 🚫 Superseded | +| 3 | Vectorise `morans_i` calculations | Medium | Low | 🚫 N/A (Low Dim) | +| 4 | HVG dispersion-based gene filtering | Medium | Medium | 🚫 Superseded | +| 5 | Library-size normalisation (CP10k) | Medium-High | Medium | ✅ Implemented | +| 6 | Per-spot QC thresholds | Medium | Medium | ✅ Implemented | | 7 | Pre-compute pathway targets (decoupleR + PROGENy) | High | Medium | 💡 Lead | -| 8 | Moran's I weighted gene loss | High | Low | 🔧 Open | +| 8 | Moran's I weighted gene loss | High | Low | 🚫 Superseded | | 9 | PROGENy pathway token initialisation | High | Medium | 💡 Lead | | 10 | Cell–cell interaction tokens (LIANA+) | Low | High | 💡 Lead | | 11 | Cell-type deconvolution secondary head | Medium | High | 💡 Lead | | 12 | Nat. Comms. 
2025 benchmark evaluation | High | Medium | 🔧 Open | | 13 | Scale to Visium HD / Xenium | Low | High | 💡 Lead | -| 14 | Preprocessing data contract doc | Low | Low | 🔧 Open | +| 14 | Preprocessing data contract doc | Low | Low | ✅ Implemented | --- diff --git a/docs/TESTING.md b/docs/TESTING.md index 78a5048..fc7d1f7 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -29,7 +29,7 @@ Tests are organized under `tests/` in subdirectories that reflect the source pac | Directory | Test Files | Coverage Area | | :--- | :--- | :--- | -| `tests/data/` | `test_data_integrity.py`, `test_pathways.py`, `test_augmentation.py`, `test_visualization.py` | Gene vocabulary, pathway scoring, data augmentation, visualization utilities | +| `tests/data/` | `test_data_integrity.py`, `test_pathways.py`, `test_augmentation.py`, `test_visualization.py` | Coordinate alignment, pathway scoring, data augmentation, visualization utilities | | `tests/models/` | `test_backbones.py`, `test_interactions.py`, `test_compilation.py` | Backbone loading, interaction model logic, `torch.compile` compatibility | | `tests/training/` | `test_losses.py`, `test_trainer.py`, `test_checkpoints.py`, `test_config.py` | Loss functions (MSE, PCC, composite), training loop, checkpoint serialization | | `tests/recipes/hest/` | HEST-specific dataset loading and split logic | HEST dataset and splitting | @@ -43,7 +43,7 @@ Tests for the offline pathway activity computation pipeline (`compute_pathway_ac - Per-spot QC filtering (min UMIs, min genes, max MT%) - CP10k normalisation correctness -- Z-scoring and mean pathway aggregation +- CP10k scaling and mean pathway aggregation - Moran I calculation - `.h5` output format and barcode alignment diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 53e9f92..c6a71a9 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -13,7 +13,7 @@ conda activate SpatialTranscriptFormer ### Pre-Compute Pathway Activity Targets -Before any training run, you must 
pre-compute the pathway activity targets from raw expression data. This step applies per-spot QC, CP10k normalisation, and z-scoring. +Before any training run, you must pre-compute the pathway activity targets from raw expression data. This step applies per-spot QC, CP10k normalisation, and mean pathway aggregation. ```bash stf-compute-pathways --data-dir hest_data @@ -25,7 +25,7 @@ This will produce `.h5` files in `hest_data/pathway_activities/` which are consu ## 1. Single Patch Regression (Baselines) -Predicts gene expression for a single 224x224 patch. No cross attention interactions between patches or pathways. +Predicts pathway activity for a single 224x224 patch. No cross attention interactions between patches or pathways. ### HE2RNA (ResNet50) @@ -57,7 +57,7 @@ python -m spatial_transcript_former.train \ ## 2. Whole-Slide MIL (Multiple Instance Learning) -Aggregates all patches from a slide to predict the average expression. Recommended to use **precomputed features** from the `stf-compute-features` CLI tool for speed. Foundation models like ctranspath can be used as backbones. +Aggregates all patches from a slide to predict the average pathway activities. Recommended to use **precomputed features** from the `stf-extract` CLI tool for speed. Foundation models like ctranspath can be used as backbones. 
### Attention MIL (Weak Supervision) diff --git a/pyproject.toml b/pyproject.toml index 30c6360..19c1a38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,11 +35,11 @@ dev = [ ] [project.scripts] -stf-download = "spatial_transcript_former.data.download:main" -stf-split = "spatial_transcript_former.data.splitting:main" +stf-download = "spatial_transcript_former.recipes.hest.download:main" +stf-split = "spatial_transcript_former.recipes.hest.splitting:main" stf-train = "spatial_transcript_former.train:main" stf-predict = "spatial_transcript_former.predict:main" -stf-extract = "spatial_transcript_former.data.extract_features:main" +stf-extract = "spatial_transcript_former.recipes.hest.extract_features:main" stf-compute-pathways = "spatial_transcript_former.recipes.hest.compute_pathway_activities:main" [tool.setuptools.packages.find] diff --git a/setup.ps1 b/setup.ps1 index 544e83c..06c9548 100644 --- a/setup.ps1 +++ b/setup.ps1 @@ -73,6 +73,6 @@ if ($HFNeedLogin) { Write-Host "You can then use the following commands:" Write-Host " stf-download --help" Write-Host " stf-split --help" -Write-Host " stf-build-vocab --help" +Write-Host " stf-compute-pathways --help" Write-Host "" Write-Host "To run tests, use: .\test.ps1" diff --git a/setup.sh b/setup.sh index e37fdd1..66567a0 100644 --- a/setup.sh +++ b/setup.sh @@ -65,6 +65,6 @@ fi echo "You can then use the following commands:" echo " stf-download --help" echo " stf-split --help" -echo " stf-build-vocab --help" +echo " stf-compute-pathways --help" echo "" echo "To run tests, use: ./test.sh" diff --git a/src/spatial_transcript_former/checkpoint.py b/src/spatial_transcript_former/checkpoint.py index 75a78f5..c424a62 100644 --- a/src/spatial_transcript_former/checkpoint.py +++ b/src/spatial_transcript_former/checkpoint.py @@ -4,7 +4,7 @@ Saves and loads a self-contained checkpoint directory containing: - config.json — architecture hyper-parameters - model.pth — model weights (state_dict) - - pathway_names.json — 
ordered list of gene symbols (optional) + - pathway_names.json — ordered list of pathway names (optional) """ import json @@ -98,7 +98,7 @@ def save_pretrained( model: A :class:`SpatialTranscriptFormer` instance. save_dir: Directory to write files into (created if needed). pathway_names: Optional ordered list of pathway names matching the - model's ``num_genes`` output dimension. + model's ``num_pathways`` output dimension. """ os.makedirs(save_dir, exist_ok=True) @@ -148,7 +148,7 @@ def load_pretrained( Returns: SpatialTranscriptFormer: The loaded model in eval mode with - ``gene_names`` attribute set (or ``None``). + ``pathway_names`` attribute set (or ``None``). """ from spatial_transcript_former.models.interaction import ( SpatialTranscriptFormer, diff --git a/src/spatial_transcript_former/data/spatial_stats.py b/src/spatial_transcript_former/data/spatial_stats.py index 8308925..d72f64d 100644 --- a/src/spatial_transcript_former/data/spatial_stats.py +++ b/src/spatial_transcript_former/data/spatial_stats.py @@ -1,15 +1,17 @@ """ -Spatial statistics utilities for gene selection. +Spatial statistics utilities for spatial transcriptomics. Provides lightweight, dependency-free Moran's I computation for -identifying spatially variable genes (SVGs) from spatial -transcriptomics data. - -Moran's I measures spatial autocorrelation: whether nearby spots tend -to have similar (positive I) or dissimilar (negative I) expression -for a given gene. Genes with high Moran's I show distinct spatial -patterns and are the strongest learning targets for -SpatialTranscriptFormer. +evaluating spatial autocorrelation of genes and biological pathways. + +Historically, Moran's I was used to rank and select Spatially Variable +Genes (SVGs) for reconstruction. In the strictly pathway-exclusive +architecture, Moran's I serves two critical roles: +1. As a pre-computed diagnostic to analyze spatial signatures and curate + pathway lists (ensuring targets are above background noise floors). 
+2. As a dynamic validation metric (via ``spatial_coherence_score``) + that compares predicted vs. ground-truth spatial patterns, acting + as an effective detector for spot-level representation collapse. """ import numpy as np @@ -50,7 +52,7 @@ def _build_knn_weights(coords: np.ndarray, k: int = 6) -> csr_matrix: def morans_i(x: np.ndarray, W: csr_matrix) -> float: - """Compute Moran's I for a single variable. + """Compute Moran's I for a single variable (e.g. pathway activity or gene expression). .. math:: @@ -59,7 +61,7 @@ def morans_i(x: np.ndarray, W: csr_matrix) -> float: {\\sum_i (x_i - \\bar{x})^2} Args: - x: (N,) array of values (e.g. gene expression per spot). + x: (N,) array of values (e.g. pathway activity score per spot). W: (N, N) sparse spatial weight matrix. Returns: @@ -126,20 +128,20 @@ def spatial_coherence_score( """Compare spatial structure of predictions vs ground truth. Computes Moran's I for both the predicted and ground-truth - expression matrices, then returns the Pearson correlation between - the two Moran's I vectors. A score near 1.0 means the model + matrices (genes or pathways), then returns the Pearson correlation + between the two Moran's I vectors. A score near 1.0 means the model reproduces the correct spatial patterns; near 0 means random. To keep computation fast (this runs every validation epoch), only - the ``top_k_genes`` with highest ground-truth spatial variability - are evaluated. + the ``top_k_genes`` (or pathways) with highest ground-truth + spatial variability are evaluated. Args: - predicted: (N, G) predicted expression matrix. - ground_truth: (N, G) ground-truth expression matrix. + predicted: (N, G) predicted expression or pathway activity matrix. + ground_truth: (N, G) ground-truth expression or pathway activity matrix. coords: (N, 2) spatial coordinates. k: KNN neighbours for the spatial weight graph. - top_k_genes: Number of top-Moran's-I genes to evaluate. 
+ top_k_genes: Number of top-Moran's-I features to evaluate. Returns: Pearson correlation between predicted and ground-truth diff --git a/src/spatial_transcript_former/models/regression.py b/src/spatial_transcript_former/models/regression.py index a6e2b9e..696ff7f 100644 --- a/src/spatial_transcript_former/models/regression.py +++ b/src/spatial_transcript_former/models/regression.py @@ -2,7 +2,7 @@ Patch-level regression baselines for spatial transcriptomics. Each model takes a single histology patch (B, 3, H, W) and regresses a -gene-expression vector of length *num_genes*. These serve as lightweight +pathway-activity vector of length *num_pathways*. These serve as lightweight baselines relative to the full SpatialTranscriptFormer. """ @@ -11,10 +11,10 @@ class HE2RNA(nn.Module): - """ResNet-50 baseline that regresses gene expression from a single patch. + """ResNet-50 baseline that regresses pathway activity from a single patch. The backbone's classification head is replaced with a linear layer of size - *num_genes*. Weights come from the ``get_backbone`` factory, so any + *num_pathways*. Weights come from the ``get_backbone`` factory, so any supported backbone identifier can be supplied via *backbone*. Reference: @@ -22,11 +22,11 @@ class HE2RNA(nn.Module): expression of tumours from whole slide images." *Nature Communications*. """ - def __init__(self, num_genes, backbone="resnet50", pretrained=True): + def __init__(self, num_pathways, backbone="resnet50", pretrained=True): """Initialise HE2RNA with the chosen backbone and output size.""" super().__init__() self.backbone, self.feature_dim = get_backbone( - backbone, pretrained=pretrained, num_classes=num_genes + backbone, pretrained=pretrained, num_classes=num_pathways ) def forward(self, x): @@ -37,7 +37,7 @@ class ViT_ST(nn.Module): """Vision Transformer baseline for spatial transcriptomics regression. 
Adapts a ViT backbone (default ``vit_b_16``) to the ST task by replacing - its classification head with a linear layer of size *num_genes*. Any + its classification head with a linear layer of size *num_pathways*. Any backbone name accepted by ``get_backbone`` can be passed via *model_name*. Reference: @@ -45,11 +45,11 @@ class ViT_ST(nn.Module): Transformers for Image Recognition at Scale." *ICLR*. """ - def __init__(self, num_genes, model_name="vit_b_16", pretrained=True): + def __init__(self, num_pathways, model_name="vit_b_16", pretrained=True): """Initialise ViT_ST with the chosen backbone and output size.""" super().__init__() self.backbone, self.feature_dim = get_backbone( - model_name, pretrained=pretrained, num_classes=num_genes + model_name, pretrained=pretrained, num_classes=num_pathways ) def forward(self, x): diff --git a/src/spatial_transcript_former/recipes/hest/build_vocab.py b/src/spatial_transcript_former/recipes/hest/build_vocab.py deleted file mode 100644 index 4736f60..0000000 --- a/src/spatial_transcript_former/recipes/hest/build_vocab.py +++ /dev/null @@ -1,265 +0,0 @@ -import os -import argparse -import h5py -import numpy as np -import pandas as pd -from tqdm import tqdm -import json -import sys -from collections import defaultdict -from scipy.sparse import csr_matrix -from spatial_transcript_former.data.spatial_stats import morans_i_batch - -# Add src to path -sys.path.append(os.path.abspath("src")) -from spatial_transcript_former.recipes.hest.io import ( - get_hest_data_dir, - load_h5ad_metadata, -) -from spatial_transcript_former.config import get_config -from spatial_transcript_former.data.pathways import ( - download_msigdb_gmt, - parse_gmt, - MSIGDB_URLS, -) - - -def scan_h5ad_files(data_dir): - """ - Find all .h5ad files in the configured data directory's 'st' subfolder. - Works for both HEST and custom datasets. 
- """ - st_dir = os.path.join(data_dir, "st") - if not os.path.exists(st_dir): - print(f"Directory not found: {st_dir}") - print("Please ensure your data matches the structure in docs/DATA_FORMAT.md") - return [] - - sample_ids = [ - f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad") - ] - - print(f"Found {len(sample_ids)} .h5ad samples in {st_dir}.") - return sample_ids - - -def calculate_global_genes( - data_dir, - ids, - num_genes=1000, - target_pathways=None, - svg_weight=0.0, - svg_k=6, -): - st_dir = os.path.join(data_dir, "st") - if not ids: - print("No samples provided for calculation.") - return [], [] - - print(f"Scanning {len(ids)} samples in {st_dir}...") - if svg_weight > 0: - print(f"SVG mode: weight={svg_weight}, k={svg_k}") - - gene_totals = defaultdict(float) - # Moran's I accumulators (sum and count for averaging across samples) - gene_morans_sum = defaultdict(float) - gene_morans_count = defaultdict(int) - - for sample_id in tqdm(ids): - h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") - - try: - # Use io utility for metadata - meta = load_h5ad_metadata(h5ad_path) - gene_names = meta["gene_names"] - - with h5py.File(h5ad_path, "r") as f: - X = f["X"] - if isinstance(X, h5py.Group): - data = X["data"][:] - indices = X["indices"][:] - indptr = X["indptr"][:] - - n_obs = len(meta["barcodes"]) - n_vars = len(gene_names) - mat = csr_matrix((data, indices, indptr), shape=(n_obs, n_vars)) - sums = np.array(mat.sum(axis=0)).flatten() - elif isinstance(X, h5py.Dataset): - mat = X[:] - sums = np.sum(mat, axis=0) - - for i, gene in enumerate(gene_names): - gene_totals[gene] += float(sums[i]) - - # --- SVG: compute Moran's I per gene for this sample --- - if svg_weight > 0 and "obsm" in f and "spatial" in f["obsm"]: - coords = f["obsm"]["spatial"][:] - # Densify the expression matrix for Moran's I - if isinstance(mat, csr_matrix): - dense_mat = mat.toarray() - else: - dense_mat = np.asarray(mat) - - mi_scores = 
morans_i_batch(dense_mat, coords, k=svg_k) - - for i, gene in enumerate(gene_names): - gene_morans_sum[gene] += mi_scores[i] - gene_morans_count[gene] += 1 - - except Exception as e: - print(f"Error processing {sample_id}: {e}") - - print(f"Aggregated counts for {len(gene_totals)} unique genes.") - if svg_weight > 0: - print(f"Computed Moran's I for {len(gene_morans_sum)} genes.") - - prioritized_genes = set() - if target_pathways: - print(f"Prioritizing genes from pathways: {target_pathways}") - - collections = ["hallmarks", "c2_kegg", "c2_medicus", "c2_cgp"] - combined_dict = {} - - for coll in collections: - url = MSIGDB_URLS[coll] - filename = url.split("/")[-1] - gmt_path = download_msigdb_gmt( - url, filename, os.path.join(data_dir, ".cache") - ) - combined_dict.update(parse_gmt(gmt_path)) - - for p in target_pathways: - if p in combined_dict: - for pw_gene in combined_dict[p]: - if pw_gene in gene_totals: - prioritized_genes.add(pw_gene) - else: - print(f"Warning: Pathway {p} not found in MSIGDB dictionaries.") - - print(f"Found {len(prioritized_genes)} valid target pathway genes.") - - # --- Ranking: expression-only or hybrid --- - all_genes = list(gene_totals.keys()) - - if svg_weight > 0 and gene_morans_sum: - # Compute average Moran's I per gene - gene_morans_avg = { - g: gene_morans_sum[g] / gene_morans_count[g] - for g in all_genes - if gene_morans_count.get(g, 0) > 0 - } - - # Rank by expression (lower rank = higher expression) - expr_sorted = sorted(all_genes, key=lambda g: gene_totals[g], reverse=True) - expr_rank = {g: r for r, g in enumerate(expr_sorted)} - - # Rank by Moran's I (lower rank = higher spatial variability) - mi_sorted = sorted( - all_genes, key=lambda g: gene_morans_avg.get(g, 0.0), reverse=True - ) - mi_rank = {g: r for r, g in enumerate(mi_sorted)} - - # Hybrid score: weighted sum of ranks (lower = better) - alpha = svg_weight - hybrid_score = { - g: (1 - alpha) * expr_rank[g] + alpha * mi_rank[g] for g in all_genes - } - 
sorted_all_genes = sorted(all_genes, key=lambda g: hybrid_score[g]) - - # Build stats list with Moran's I column - sorted_all = [ - (g, gene_totals[g], gene_morans_avg.get(g, 0.0)) for g in sorted_all_genes - ] - print( - f"Hybrid ranking: expression weight={(1 - alpha):.1f}, " - f"SVG weight={alpha:.1f}" - ) - else: - # Expression-only ranking (original behaviour) - sorted_all = sorted(gene_totals.items(), key=lambda x: x[1], reverse=True) - sorted_all_genes = [g for g, _ in sorted_all] - # Pad stats tuples with 0.0 Moran's I for consistent CSV format - sorted_all = [(g, c, 0.0) for g, c in sorted_all] - - top_genes = list(prioritized_genes) - for g in sorted_all_genes: - if len(top_genes) >= num_genes: - break - if g not in prioritized_genes: - top_genes.append(g) - - print( - f"Final set: {len(prioritized_genes)} pathway genes + " - f"{len(top_genes) - len(prioritized_genes)} global genes" - ) - - return top_genes, sorted_all - - -def main(): - parser = argparse.ArgumentParser( - description="Build Global Gene Vocabulary from .h5ad files" - ) - parser.add_argument( - "--data-dir", - type=str, - default=get_config("data_dirs", ["hest_data"])[0], - help="Root directory containing the 'st' subfolder", - ) - parser.add_argument( - "--num-genes", - type=int, - default=get_config("training.num_genes", 1000), - help="Maximum number of global genes to select", - ) - parser.add_argument( - "--pathways", - nargs="+", - default=None, - help="List of MSigDB pathway names to explicitly prioritize (e.g., HALLMARK_P53_PATHWAY)", - ) - parser.add_argument( - "--svg-weight", - type=float, - default=0.0, - help="Weight for spatial variability (Moran's I) in gene ranking. 
" - "0.0=expression-only (default), 1.0=SVG-only, 0.5=balanced.", - ) - parser.add_argument( - "--svg-k", - type=int, - default=6, - help="Number of KNN neighbours for spatial weight matrix (default: 6).", - ) - - args = parser.parse_args() - - # Output directly to the specified data directory - output_path = os.path.join(args.data_dir, "global_genes.json") - - ids = scan_h5ad_files(args.data_dir) - - if not ids: - print("Vocabulary builder aborted.") - sys.exit(1) - - top_genes, all_stats = calculate_global_genes( - args.data_dir, - ids, - args.num_genes, - target_pathways=args.pathways, - svg_weight=args.svg_weight, - svg_k=args.svg_k, - ) - - print(f"Saving top {len(top_genes)} genes to {output_path}") - with open(output_path, "w") as f: - json.dump(top_genes, f, indent=4) - - stats_df = pd.DataFrame(all_stats, columns=["gene", "total_counts", "morans_i"]) - stats_df.to_csv(output_path.replace(".json", "_stats.csv"), index=False) - print("Saved stats to CSV.") - - -if __name__ == "__main__": - main() diff --git a/src/spatial_transcript_former/recipes/hest/dataset.py b/src/spatial_transcript_former/recipes/hest/dataset.py index ba5c917..7328d32 100644 --- a/src/spatial_transcript_former/recipes/hest/dataset.py +++ b/src/spatial_transcript_former/recipes/hest/dataset.py @@ -241,7 +241,6 @@ def get_hest_dataloader( shuffle (bool): Whether to shuffle at each epoch. num_workers (int): DataLoader worker processes. transform (callable, optional): Transform applied to each patch tensor. - num_genes (int): Number of genes per sample. n_neighbors (int): Number of spatial neighbours to include per patch. ``0`` disables neighbourhood mode. augment (bool): Whether to apply dihedral augmentations. @@ -362,9 +361,6 @@ class HEST_FeatureDataset(SpatialDataset): Args: feature_path (str): Path to the ``.pt`` feature file. h5ad_path (str): Path to the corresponding ``.h5ad`` expression file. - num_genes (int): Number of genes expected in targets. 
- selected_gene_names (List[str], optional): Gene names to align targets - to. ``None`` uses discovery mode on the first loaded sample. n_neighbors (int): Spatial neighbours to include per patch. use_global_context (bool): Whether to append randomly sampled slide-level context patches to each neighbourhood sequence. diff --git a/src/spatial_transcript_former/training/builder.py b/src/spatial_transcript_former/training/builder.py index ddaba85..8329a43 100644 --- a/src/spatial_transcript_former/training/builder.py +++ b/src/spatial_transcript_former/training/builder.py @@ -27,13 +27,13 @@ def setup_model(args, device): if args.model == "he2rna": model = HE2RNA( - num_genes=args.num_pathways, + num_pathways=args.num_pathways, backbone=args.backbone, pretrained=args.pretrained, ) elif args.model == "vit_st": model = ViT_ST( - num_genes=args.num_pathways, + num_pathways=args.num_pathways, model_name=args.backbone if "vit_" in args.backbone else "vit_b_16", pretrained=args.pretrained, ) diff --git a/src/spatial_transcript_former/training/trainer.py b/src/spatial_transcript_former/training/trainer.py index 9a3bc70..ae63037 100644 --- a/src/spatial_transcript_former/training/trainer.py +++ b/src/spatial_transcript_former/training/trainer.py @@ -10,7 +10,7 @@ from spatial_transcript_former.training import Trainer from spatial_transcript_former.training.losses import CompositeLoss - model = SpatialTranscriptFormer(num_genes=460, backbone_name="phikon", ...) + model = SpatialTranscriptFormer(num_pathways=50, backbone_name="phikon", ...) 
trainer = Trainer( model=model, train_loader=train_dl, @@ -19,7 +19,7 @@ epochs=100, ) results = trainer.fit() - trainer.save_pretrained("./release/v1/", gene_names=my_genes) + trainer.save_pretrained("./release/v1/", pathway_names=my_pathways) """ import math @@ -380,7 +380,7 @@ def fit(self) -> Dict[str, Any]: # ------------------------------------------------------------------ def save_pretrained( - self, path: str, gene_names: Optional[List[str]] = None + self, path: str, pathway_names: Optional[List[str]] = None ) -> None: """Export an inference-ready checkpoint (strips optimizer state). @@ -390,4 +390,4 @@ def save_pretrained( save_pretrained as _save_pretrained, ) - _save_pretrained(self.model, path, gene_names=gene_names) + _save_pretrained(self.model, path, pathway_names=pathway_names) diff --git a/tests/models/test_backbones.py b/tests/models/test_backbones.py index 2ee1caa..f4d4f5e 100644 --- a/tests/models/test_backbones.py +++ b/tests/models/test_backbones.py @@ -64,7 +64,7 @@ def test_trans_mil_backbone(): def test_he2rna_backbone(): num_pathways = 50 - model = HE2RNA(num_genes=num_pathways, backbone="resnet50", pretrained=False) + model = HE2RNA(num_pathways=num_pathways, backbone="resnet50", pretrained=False) x = torch.randn(2, 3, 224, 224) out = model(x) assert out.shape == (2, num_pathways) @@ -73,7 +73,7 @@ def test_he2rna_backbone(): def test_vit_st_backbone(): num_pathways = 50 # Use resnet50 name to test fallback if vit_b_16 not available or just use vit_b_16 directly - model = ViT_ST(num_genes=num_pathways, model_name="resnet50", pretrained=False) + model = ViT_ST(num_pathways=num_pathways, model_name="resnet50", pretrained=False) x = torch.randn(2, 3, 224, 224) out = model(x) assert out.shape == (2, num_pathways)