
saim-x/spin-road-mapper


spin-road-mapper

A from-scratch PyTorch implementation of SPIN Road Mapper, the graph-reasoning architecture for road segmentation from aerial images. Built on top of stacked hourglass convolutional networks with a plug-in SPIN module that reasons over two graph spaces to catch long-range road connectivity that convolutions miss.

We wrote every module ourselves, cross-checking against the reference repo and the paper. The result is three model variants, two datasets, 11 passing shape tests, and a controlled ablation in which the graph reasoning raises IoU by 5.2 points (0.836 to 0.888) over the identical network with SPIN stripped out. That ablation is an overfit test on a handful of crops, not a full training run; see the results summary below.

Paper: Bandara et al. — SPIN Road Mapper: Extracting Roads from Aerial Images via Spatial and Interaction Space Graph Reasoning for Autonomous Driving (arXiv:2109.07701).


What this does

Convolutional networks see locally. Stacking more layers grows the receptive field, but only slowly. Roads form connected networks: a segment on one side of an image can belong to the same route as a segment on the far side, with a patch of trees or a bridge in between, and the network has no built-in way to link them.

SPIN addresses this by building graphs over feature maps in two spaces:

  • Spatial space: every position gets to talk to every other position through a learned adjacency matrix. Two road pixels that are far apart in the image can share information in a single graph convolution step.
  • Interaction space: features project into a learned latent space where semantic categories separate. Graph reasoning in this space helps the network tell roads apart from buildings, vegetation, and water — stuff that often looks similar in RGB.

The SPIN modules in the DGCN variant add about 481K parameters (29,482,908 vs 29,002,012) to the 29M-parameter hourglass network. That is roughly 1.7% more parameters for a 5.2-point IoU gain on our DeepGlobe overfit test.


Results (quick summary)

| Variant                   | IoU   | Pixel Acc | Parameters |
|---------------------------|-------|-----------|------------|
| Baseline (hourglass only) | 0.836 | 0.977     | 29.0M      |
| SPIN DGCN                 | 0.888 | 0.986     | 29.5M      |

These numbers come from an overfit test on 5 fixed DeepGlobe crops trained for 200 epochs, not a full-scale training run, so they measure how well each network fits its training crops rather than how well it generalizes. The point was to isolate the graph reasoning contribution: same data, same hyperparameters, same backbone, and the only difference is whether SPIN modules are plugged in. Most of the gap comes from SPIN producing far fewer false positives (1,098 vs 7,121).


Project structure

spin-road-mapper/
├── configs/
│   └── config.json              # All hyperparameters, dataset paths, model config
├── models/
│   ├── backbone.py              # FeatureExtractor, BasicResnetBlock, DecoderBlock
│   ├── hourglass.py             # HourglassModuleMTL — recursive dual-path hourglass
│   ├── spin.py                  # SpatialGCN + full SPIN module (the core contribution)
│   ├── network.py               # 3 model variants: baseline, DGCN, SPIN pyramid
│   └── registry.py              # Model name → class lookup
├── datasets/
│   ├── road_dataset.py          # RoadDataset base, MassachusettsDataset, DeepGlobeDataset
│   ├── fast_dataset.py          # Lightweight dataset (segmentation only, no orientation)
│   ├── cached_dataset.py        # Dataset with pre-computed orientation maps
│   ├── affinity_utils.py        # Orientation/angle map generation via skeletonization
│   ├── graph_utils.py           # RDP simplification, segment-to-linestring conversion
│   ├── sknw.py                  # Skeleton-to-NetworkX graph builder (from yxdragon/sknw)
│   └── rdp.py                   # Ramer-Douglas-Peucker line simplification
├── utils/
│   ├── loss.py                  # mIoULoss, CrossEntropyLoss2d, MultiTaskLoss
│   ├── metrics.py               # IoU, F1, pixel accuracy, relaxed F1
│   └── utils.py                 # Seed, weight init, checkpointing, AverageMeter
├── preprocessing/
│   ├── preprocess.py            # Full pipeline: filter corrupted images, crop, split
│   └── create_crops.py          # Tile large images into 256x256 training crops
├── train.py                     # Main training script (configurable, supports overfit mode)
├── train_full.py                # Full training with FP16 mixed precision + gradient accumulation
├── train_ablation.py            # Fast ablation on DeepGlobe subset
├── final_train.py               # Overfit head-to-head: baseline vs SPIN DGCN
├── overfit_test.py              # Overfit sanity check with fixed crops
├── evaluate.py                  # Evaluation with all metrics
├── test_shape.py                # 11 shape validation tests (one per module)
├── visualize_batch.py           # Batch visualization with GT + angle overlay
├── precompute_vecmaps.py        # Pre-compute orientation maps to skip runtime bottleneck
├── download_deepglobe.py        # Kaggle downloader for DeepGlobe
├── download_massachusetts.py    # Web scraper for Massachusetts dataset
├── requirements.txt
└── research_paper.pdf           # Original SPIN Road Mapper paper

Three model variants

All three share the same hourglass backbone and feature extractor. The only structural difference is which SPIN modules are present and where.

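| Model                             | Graph reasoning                                  | Params     | When to use                                  |
|-----------------------------------|--------------------------------------------------|------------|----------------------------------------------|
| StackHourglassNetMTL              | None                                             | 29,002,012 | Baseline for ablation comparisons            |
| StackHourglassNetMTL_DGCN         | SPIN at hourglass bottleneck (seg + angle paths) | 29,482,908 | Best trade-off; main variant                 |
| StackHourglassNetMTL_SPIN_PYRAMID | 3-scale SPIN pyramid in decoder                  | 29,271,692 | More multi-scale context, less stable to train |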

The pyramid model works but needs care. Three cascaded SPIN modules make activations explode unless the pyramid outputs are averaged instead of summed. Averaging stabilizes training, but the DGCN variant is the one we ran experiments on.
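The effect of averaging versus summing can be illustrated with a small fusion helper. This is a hypothetical sketch, not code from network.py: `fuse_pyramid` is an illustrative name, and it assumes each pyramid level's output is upsampled to a common resolution before the levels are combined.

```python
import torch
import torch.nn.functional as F


def fuse_pyramid(level_outputs, size, reduce="mean"):
    """Resize each pyramid level's output to `size` and combine them.

    level_outputs: list of (N, C, h_i, w_i) tensors, one per pyramid level.
    size: (H, W) target resolution for the fused map.
    """
    resized = [
        F.interpolate(o, size=size, mode="bilinear", align_corners=False)
        for o in level_outputs
    ]
    stacked = torch.stack(resized, dim=0)  # (levels, N, C, H, W)
    if reduce == "mean":
        # Output magnitude stays on the scale of a single level.
        return stacked.mean(dim=0)
    if reduce == "sum":
        # Output magnitude grows with the number of levels.
        return stacked.sum(dim=0)
    raise ValueError(f"unknown reduce mode: {reduce!r}")


# Three levels at different resolutions, each with activations of about 1.0.
levels = [torch.ones(1, 8, s, s) for s in (16, 32, 64)]
print(fuse_pyramid(levels, (64, 64), "sum").mean().item())   # about 3.0
print(fuse_pyramid(levels, (64, 64), "mean").mean().item())  # about 1.0
```

Averaging keeps the fused output on the scale of a single level, whereas summing scales it with the number of levels. It does not bound the activations inside each SPIN module, so normalization between levels may still be worth adding.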


How the SPIN module works

The module splits into two parallel branches that fuse at the end:

Spatial branch: Three depthwise-separable strided convolutions downsample the feature map by 8x. A SpatialGCN builds a graph over the compressed spatial positions, runs one round of message passing, and bilinear upsampling brings it back. The result gets gated (x * local + x) to preserve fine spatial detail.

Interaction branch: Two linear projections — phi expands channels and theta contracts them — map features into a compact latent coordinate space. A 1D convolution processes the interaction matrix, followed by Laplacian smoothing (identity residual on the adjacency) and a state update. The result gets re-projected back to spatial dimensions via another multiplication with theta. A residual gate wraps it: ReLU(x + Y).
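A minimal PyTorch sketch of the interaction branch described above follows. It is not this repo's spin.py: the class name, the BatchNorm placement, and the channel widths (theta maps to C/4 latent nodes, phi to twice that) are assumptions modeled on DualGCN-style code, so here phi is simply the wider of the two projections.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InteractionBranch(nn.Module):
    """Sketch of the interaction-space branch; channel widths are assumptions."""

    def __init__(self, channels: int, ratio: int = 4):
        super().__init__()
        nodes = channels // ratio  # theta output: number of latent nodes
        state = 2 * nodes          # phi output: feature size carried by each node
        self.phi = nn.Sequential(
            nn.Conv2d(channels, state, 1, bias=False), nn.BatchNorm2d(state))
        self.theta = nn.Sequential(
            nn.Conv2d(channels, nodes, 1, bias=False), nn.BatchNorm2d(nodes))
        # 1D conv across the node axis plays the role of the learned adjacency.
        self.conv_adj = nn.Sequential(
            nn.Conv1d(nodes, nodes, 1, bias=False), nn.BatchNorm1d(nodes))
        # 1D conv across the state axis is the state update.
        self.conv_state = nn.Sequential(
            nn.Conv1d(state, state, 1, bias=False), nn.BatchNorm1d(state))
        self.out = nn.Sequential(
            nn.Conv2d(state, channels, 1, bias=False), nn.BatchNorm2d(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, _, h, w = x.shape
        v = self.phi(x).flatten(2)                    # (n, state, HW)
        b = self.theta(x).flatten(2)                  # (n, nodes, HW)
        z_idt = torch.bmm(v, b.transpose(1, 2))       # (n, state, nodes): interaction matrix
        z = self.conv_adj(z_idt.transpose(1, 2).contiguous())  # mix information across nodes
        z = z.transpose(1, 2) + z_idt                 # identity residual (Laplacian smoothing)
        z = self.conv_state(z.contiguous())           # state update
        y = torch.bmm(z, b).view(n, -1, h, w)         # re-project with theta: (n, state, H, W)
        return F.relu(x + self.out(y))                # residual gate ReLU(x + Y)


branch = InteractionBranch(channels=128)
y = branch(torch.randn(2, 128, 16, 16))
print(y.shape)  # torch.Size([2, 128, 16, 16])
```

The interaction matrix `z_idt` has shape (state, nodes), so its size depends only on the channel count, not on H x W.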

Fusion: The spatial and interaction outputs concatenate channel-wise and pass through a 1x1 convolution to halve the channels back to the original dimension.

The key point is that the adjacency matrix stays small. SpatialGCN computes A = softmax(Q·V^T) over channels, not spatial positions, so the graph is C/2 x C/2 regardless of image resolution. For 128 input channels, that is a 64x64 matrix. Spatial size only enters through the sum over positions inside Q·V^T, and the spatial branch shrinks that sum by downsampling 8x before the GCN: a 256x256 feature map, for example, becomes 32x32 = 1,024 positions. This is why the whole thing fits on a consumer GPU.
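The sketch below shows the spatial branch with a channel-wise SpatialGCN, mainly to make the adjacency shape concrete. Class names, layer choices (BatchNorm after each downsampling stage, a 1x1 output projection with a residual), and the tensor sizes in the usage lines are assumptions rather than this repo's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialGCN(nn.Module):
    """Graph over channels: the adjacency is (C/2 x C/2) whatever H and W are."""

    def __init__(self, channels: int):
        super().__init__()
        inter = channels // 2
        self.q = nn.Conv2d(channels, inter, 1)
        self.k = nn.Conv2d(channels, inter, 1)
        self.v = nn.Conv2d(channels, inter, 1)
        self.update = nn.Sequential(nn.Conv1d(inter, inter, 1, bias=False), nn.BatchNorm1d(inter))
        self.out = nn.Sequential(nn.Conv2d(inter, channels, 1), nn.BatchNorm2d(channels))

    @staticmethod
    def adjacency(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # (n, C/2, HW) @ (n, HW, C/2) -> (n, C/2, C/2); the HW axis is summed away.
        return torch.softmax(torch.bmm(q, v.transpose(1, 2)), dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, _, h, w = x.shape
        q, k, v = (proj(x).flatten(2) for proj in (self.q, self.k, self.v))
        adj = self.adjacency(q, v)
        # One round of message passing: mix each position's K features through A.
        msg = torch.bmm(k.transpose(1, 2), adj).transpose(1, 2).contiguous()  # (n, C/2, HW)
        y = self.update(msg).view(n, -1, h, w)
        return F.relu(self.out(y) + x)


class SpatialBranch(nn.Module):
    """Downsample 8x, run SpatialGCN, upsample, then gate: x * local + x."""

    def __init__(self, channels: int):
        super().__init__()
        layers = []
        for _ in range(3):  # three stride-2 depthwise-separable convs: 2 * 2 * 2 = 8x
            layers += [
                nn.Conv2d(channels, channels, 3, stride=2, padding=1, groups=channels, bias=False),
                nn.Conv2d(channels, channels, 1, bias=False),
                nn.BatchNorm2d(channels),
            ]
        self.down = nn.Sequential(*layers)
        self.gcn = SpatialGCN(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local = self.gcn(self.down(x))
        local = F.interpolate(local, size=x.shape[-2:], mode="bilinear", align_corners=True)
        return x * local + x


branch = SpatialBranch(128)
print(branch(torch.randn(2, 128, 64, 64)).shape)  # torch.Size([2, 128, 64, 64])
q = torch.randn(2, 64, 32 * 32)                   # 1,024 positions after 8x downsampling
v = torch.randn(2, 64, 32 * 32)
print(SpatialGCN.adjacency(q, v).shape)           # torch.Size([2, 64, 64])
```

Feeding a 32x32 (1,024-position) map or a much larger one into `adjacency` yields the same 64x64 matrix for 128 channels; only the cost of the product over positions changes.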


Datasets

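| Dataset             | Images                              | Resolution | Format          | Split                         |
|---------------------|-------------------------------------|------------|-----------------|-------------------------------|
| Massachusetts Roads | 1,155 (after removing 16 corrupted) | 1500x1500  | GeoTIFF         | 981 train / 174 val (85/15)   |
| DeepGlobe           | 6,226                               | 1024x1024  | JPEG + PNG mask | 5,603 train / 623 val (90/10) |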

Massachusetts covers urban, suburban, and rural terrain across the state. DeepGlobe has varied landscapes from Thailand, Indonesia, and India — farmland, dense cities, coastlines, mountains. We found DeepGlobe JPEGs load vastly faster than Massachusetts TIFFs, which matters when you are iterating on a training loop.

Augmentations per sample: random 256x256 crop, horizontal flip with 50% probability, random 90-degree rotation (0, 90, 180, or 270).
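A NumPy sketch of these augmentations, applying identical transforms to image and mask. The function name `augment` and the (H, W, C) image / (H, W) mask layout are assumptions for illustration, not the repo's dataset code.

```python
import numpy as np


def augment(image: np.ndarray, mask: np.ndarray, crop: int = 256, rng=None):
    """Random crop, 50% horizontal flip, random 90-degree rotation.

    image: (H, W, C) array; mask: (H, W) array. Both receive the same transform.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = mask.shape
    top = int(rng.integers(0, h - crop + 1))
    left = int(rng.integers(0, w - crop + 1))
    image = image[top:top + crop, left:left + crop]
    mask = mask[top:top + crop, left:left + crop]
    if rng.random() < 0.5:                 # horizontal flip with 50% probability
        image, mask = image[:, ::-1], mask[:, ::-1]
    k = int(rng.integers(0, 4))            # rotate by 0, 90, 180, or 270 degrees
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    # Flips and rotations return views; make them contiguous for torch.from_numpy.
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(300, 300, 3)).astype(np.uint8)
mask = (image[..., 0] > 200).astype(np.uint8)
crop_img, crop_mask = augment(image, mask, rng=rng)
print(crop_img.shape, crop_mask.shape)  # (256, 256, 3) (256, 256)
```

If orientation labels are trained as well, flipping or rotating the image also changes which way each road points, so the angle values would need remapping, not just the pixel positions.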


Setup

Requirements:

  • Python 3.10+
  • PyTorch 2.x with CUDA
  • NVIDIA GPU with at least 8 GB VRAM (tested on RTX 3060 12 GB)

Install:

git clone https://github.com/saim-x/spin-road-mapper.git
cd spin-road-mapper
pip install -r requirements.txt

Datasets:

# DeepGlobe (requires Kaggle API)
python download_deepglobe.py

# Massachusetts Roads
python download_massachusetts.py

Preprocessing:

python preprocessing/preprocess.py

This filters out corrupted images, generates random 256x256 crops, and writes train/val split files to data/processed/.


Training

Overfit sanity check (recommended first step)

Makes sure the architecture, data pipeline, and loss function all work together before you spend hours on real training:

python overfit_test.py

Trains on 5 fixed crops for 200 epochs. Both baseline and SPIN DGCN should hit IoU above 0.8.
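The idea behind the overfit test can be sketched as a generic helper. This is not overfit_test.py: `overfit_check` is a hypothetical name, the tiny two-layer conv net and synthetic masks stand in for the real model and the 5 DeepGlobe crops, and the loss mirrors the weighted BCE described under Key implementation notes.

```python
import torch
import torch.nn as nn


def overfit_check(model, images, masks, epochs=200, lr=1e-3, pos_weight=10.0):
    """Train repeatedly on one fixed batch; return the loss at the first and last epoch.

    A working model/loss/data path should drive the loss far down. If it barely
    moves, something upstream is broken.
    """
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=masks.device))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    first = last = None
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()
        if epoch == 0:
            first = loss.item()
        last = loss.item()
    return first, last


torch.manual_seed(0)
tiny = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 1, 3, padding=1))
images = torch.rand(5, 3, 32, 32)        # stand-in for 5 fixed crops
masks = (images[:, :1] > 0.9).float()    # synthetic "road" labels
first, last = overfit_check(tiny, images, masks, lr=1e-2)
print(f"loss {first:.3f} -> {last:.3f}")
```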

Ablation comparison

Runs SPIN DGCN and the baseline head-to-head on identical data:

python final_train.py

DeepGlobe training (subset and full)

python train_ablation.py     # 400 train / 100 val, 50 epochs
python train_full.py          # Full dataset, RTX 3060 optimized

Shape tests

11 tests covering every module — input/output shapes, no NaNs:

python test_shape.py

Massachusetts Roads full training

# Pre-compute orientation maps first (this is the bottleneck)
python precompute_vecmaps.py --dataset massachusetts

# Then train
python train.py --dataset massachusetts --model spin_dgcn

Hardware adaptations

The original paper runs on an RTX 8000 with 48 GB VRAM. We have an RTX 3060 with 12 GB. Here is what made it work:

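| Parameter       | Paper (RTX 8000) | Ours (RTX 3060)                                   |
|-----------------|------------------|---------------------------------------------------|
| Batch size      | 16               | 8 (physical)                                      |
| Effective batch | 16               | 16 (8 x 2 gradient accumulation steps)            |
| Precision       | FP32             | FP16 mixed precision (AMP), roughly half the VRAM |
| Crop size       | 256x256          | 256x256                                           |
| Workers         | 4                | 2–4                                               |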

FP16 mixed precision was the real win: no measurable accuracy loss and roughly half the memory in our runs. Gradient accumulation handles the rest. The 29M-parameter network with SPIN modules trains fine on a consumer card; the actual bottleneck is data loading and orientation map computation, not the model forward pass.
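A minimal sketch of FP16 autocast plus gradient accumulation in PyTorch. `train_epoch` is a hypothetical helper, not train_full.py; it assumes a loader yielding (images, masks) batches and falls back to plain FP32 on CPU, where the scaler and autocast are disabled. `torch.cuda.amp.GradScaler` works across PyTorch 2.x, though newer releases prefer `torch.amp.GradScaler("cuda")`.

```python
import torch
import torch.nn as nn


def train_epoch(model, loader, optimizer, criterion, scaler, device, accum_steps=2):
    """One epoch of mixed-precision training with gradient accumulation.

    Returns the number of optimizer steps taken.
    """
    use_amp = scaler.is_enabled()
    model.train()
    optimizer.zero_grad(set_to_none=True)
    num_batches = len(loader)
    steps = 0
    for i, (images, masks) in enumerate(loader, start=1):
        images, masks = images.to(device), masks.to(device)
        # autocast chooses FP16 or FP32 per op; disabled entirely on CPU here.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = criterion(model(images), masks) / accum_steps
        # Scale the loss so small FP16 gradients don't underflow, then accumulate.
        scaler.scale(loss).backward()
        # Step every accum_steps batches; a trailing partial group is stepped too.
        if i % accum_steps == 0 or i == num_batches:
            scaler.step(optimizer)  # unscales gradients, skips the step on inf/NaN
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            steps += 1
    return steps


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Conv2d(3, 1, 3, padding=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0], device=device))
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
loader = [(torch.rand(8, 3, 64, 64), (torch.rand(8, 1, 64, 64) > 0.95).float()) for _ in range(5)]
steps = train_epoch(model, loader, optimizer, criterion, scaler, device, accum_steps=2)
print(steps)  # 3: two full accumulation groups plus one trailing batch
```

Dividing each loss by `accum_steps` makes two accumulated batches of 8 produce the same gradient scale as one batch of 16. BatchNorm layers still see only 8 samples per forward pass, so accumulation does not fully replicate a physical batch of 16.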


What orientation maps are (and why we skipped them for ablation)

The paper uses a second task: predicting per-pixel road orientation angles, quantized into 37 bins (36 angle buckets of 10 degrees plus background). The idea is that knowing which way a road is going helps the segmentation branch stay connected across gaps.
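The quantization step itself is simple. This sketch assumes angles in degrees, 10-degree bins covering 0 to 360, and background assigned to label 36; the function name and the background index are illustrative assumptions, not necessarily what affinity_utils.py uses.

```python
import numpy as np

NUM_ANGLE_BINS = 36   # 36 buckets of 10 degrees over a full circle
BACKGROUND_BIN = 36   # assumption: background takes the 37th label


def quantize_angles(angles_deg, road_mask, bin_width=10.0):
    """Map per-pixel angles to bin indices 0..35, and non-road pixels to 36."""
    bins = np.floor(np.mod(angles_deg, 360.0) / bin_width).astype(np.int64)
    bins = np.clip(bins, 0, NUM_ANGLE_BINS - 1)  # guard against float edge cases near 360
    return np.where(road_mask > 0, bins, BACKGROUND_BIN)


angles = np.array([0.0, 9.9, 10.0, 359.9, -5.0, 45.0])
mask = np.array([1, 1, 1, 1, 1, 0])
print(quantize_angles(angles, mask))  # labels 0, 0, 1, 35, 35, 36
```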

Generating these maps is expensive. For each image you skeletonize the binary mask, build a graph from the skeleton, simplify it with RDP, extract keypoints, and compute per-pixel angle vectors. That entire pipeline per image takes longer than the model's forward and backward pass combined. For our ablation experiments we dropped the orientation task and trained on segmentation only. The SPIN improvement still showed up, which suggests the graph reasoning does useful work even without the angle signal.
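Of those stages, RDP simplification is the easiest to show in isolation. Below is a textbook recursive Ramer-Douglas-Peucker sketch on lists of (x, y) points; it is a generic version for illustration, not the code in datasets/rdp.py.

```python
import math


def rdp(points, epsilon):
    """Ramer-Douglas-Peucker: keep the endpoints, and recursively keep the point
    farthest from the chord whenever its distance exceeds epsilon."""
    if len(points) < 3:
        return list(points)
    (x1, y1), (x2, y2) = points[0], points[-1]
    dx, dy = x2 - x1, y2 - y1
    chord = math.hypot(dx, dy)
    max_dist, index = -1.0, 0
    for i in range(1, len(points) - 1):
        px, py = points[i]
        if chord == 0:  # closed loop: fall back to distance from the endpoint
            dist = math.hypot(px - x1, py - y1)
        else:           # perpendicular distance from (px, py) to the chord
            dist = abs(dy * px - dx * py + x2 * y1 - y2 * x1) / chord
        if dist > max_dist:
            max_dist, index = dist, i
    if max_dist > epsilon:
        left = rdp(points[: index + 1], epsilon)
        right = rdp(points[index:], epsilon)
        return left[:-1] + right  # drop the duplicated split point
    return [points[0], points[-1]]


path = [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]
print(rdp(path, epsilon=0.5))  # [(0, 0), (2, 0), (2, 2)]
```

Larger `epsilon` values drop more intermediate skeleton points, which makes the later keypoint and angle computation cheaper but coarser.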

If you want orientation maps, either precompute them with precompute_vecmaps.py or use CachedRoadDataset, which reads pre-saved maps. The skeletonization code lives in datasets/affinity_utils.py and datasets/sknw.py.


Evaluation metrics

evaluate.py computes:

  • IoU (Intersection over Union): computed per class; the road class is the one we care about
  • Pixel accuracy: percentage of correctly classified pixels
  • F1 score: harmonic mean of precision and recall
  • Relaxed F1: F1 with a spatial tolerance buffer. A predicted road pixel counts as a match if it is within a few pixels of a ground-truth road pixel, and relaxed recall applies the same tolerance in the other direction
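The four metrics can be sketched on binary masks as follows. The function name, the default `radius=3`, and the square (max-pooling) tolerance window are assumptions: the text only says "a few pixels", and relaxed metrics are often defined with a Euclidean distance instead.

```python
import torch
import torch.nn.functional as F


def _dilate(mask: torch.Tensor, radius: int) -> torch.Tensor:
    # Square (Chebyshev-distance) dilation via max pooling; mask is (N, 1, H, W).
    return F.max_pool2d(mask, kernel_size=2 * radius + 1, stride=1, padding=radius)


def road_metrics(pred, target, radius=3, eps=1e-7):
    """pred, target: (N, 1, H, W) binary {0, 1} tensors for the road class."""
    pred, target = pred.float(), target.float()
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    tn = ((1 - pred) * (1 - target)).sum()
    iou = tp / (tp + fp + fn + eps)
    pixel_acc = (tp + tn) / (tp + tn + fp + fn + eps)
    precision, recall = tp / (tp + fp + eps), tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    # Relaxed precision: predicted pixels near any GT road pixel.
    # Relaxed recall: GT pixels near any predicted road pixel.
    relaxed_p = (pred * _dilate(target, radius)).sum() / (pred.sum() + eps)
    relaxed_r = (target * _dilate(pred, radius)).sum() / (target.sum() + eps)
    relaxed_f1 = 2 * relaxed_p * relaxed_r / (relaxed_p + relaxed_r + eps)
    values = dict(iou=iou, pixel_acc=pixel_acc, f1=f1, relaxed_f1=relaxed_f1)
    return {name: float(v) for name, v in values.items()}


target = torch.zeros(1, 1, 32, 32)
target[..., 10] = 1                 # a vertical one-pixel "road"
pred = torch.zeros_like(target)
pred[..., 12] = 1                   # the same road, shifted two pixels sideways
print(road_metrics(pred, target))
```

With a prediction shifted two pixels sideways, IoU and F1 drop to zero while relaxed F1 stays near 1, which is the slack the relaxed metric is meant to give for slightly misaligned labels.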

Key implementation notes

  • Multi-scale deep supervision: the network outputs predictions at 4 scales (hourglass stack 1, stack 2, decoder at half-res, final full-res). Loss gets computed at all 4 and summed. Ground truth labels resize to match each scale.
  • BCEWithLogitsLoss with pos_weight=10: road pixels make up 3–8% of each crop. Without the weight, the model quickly learns that predicting "everything is background" already earns a low loss. The weight multiplies the loss on road pixels by ten, so missed roads cost far more.
  • On-the-fly cropping: instead of pre-generating thousands of tiles, each epoch randomly crops from full images. More augmentation variety, less disk usage.
  • SPIN pyramid stabilization: three cascaded SPIN modules push logit values into the thousands (softmax overflow, NaN loss). Averaging the pyramid outputs instead of summing them fixes it.
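The first two notes combine into a single loss function. This sketch assumes the four outputs arrive as a list of single-channel logit maps and that the target is resized with nearest-neighbour interpolation; `deep_supervision_loss` and the example resolutions are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def deep_supervision_loss(outputs, target, pos_weight=10.0):
    """Sum weighted BCE-with-logits over predictions at several resolutions.

    outputs: list of (N, 1, h_i, w_i) logit maps, one per supervised scale.
    target:  (N, 1, H, W) binary road mask at full resolution.
    """
    criterion = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([pos_weight], device=target.device))
    total = torch.zeros((), device=target.device)
    for logits in outputs:
        # Nearest-neighbour resizing keeps the labels strictly 0/1.
        scaled = F.interpolate(target.float(), size=logits.shape[-2:], mode="nearest")
        total = total + criterion(logits, scaled)
    return total


# Hypothetical scales: two hourglass stacks, a half-res decoder, and full res.
sizes = (64, 64, 128, 256)
outputs = [torch.zeros(2, 1, s, s) for s in sizes]
target = torch.zeros(2, 1, 256, 256)
print(deep_supervision_loss(outputs, target).item())  # 4 * ln(2), about 2.7726
```

Nearest-neighbour downsampling can drop one-pixel-wide roads at coarse scales; max pooling the mask is an alternative if that becomes a problem.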

Why the overfit-first approach

We trained on 5 fixed images to convergence before touching the full dataset. If the architecture, data pipeline, or loss function were broken, 5 images would not overfit and the loss would not drop. Both models converged to high IoU on the overfit test, which told us the pipeline was wired up correctly end to end. Only then did we scale up. This saved a lot of time: nothing is worse than running a full training loop for hours before noticing a shape mismatch in layer 47.


What we would do differently

The orientation map pipeline (skeletonization + graph extraction + angle quantization) is the project's biggest engineering headache. It depends on scikit-image, networkx, numba, and a custom fork of sknw, and the per-image compute time dominates everything else. If we started over, we would either precompute orientation maps once and cache them aggressively, or drop the orientation task entirely and focus on segmentation-only SPIN, which already shows the improvement.

The SPIN pyramid model is structurally sound, but stacking three graph reasoning modules without intermediate normalization causes activations to explode. Averaging the outputs works, but layer normalization between pyramid levels would be cleaner.

Cross-dataset generalization (train on Massachusetts, test on DeepGlobe, or vice versa) is the next experiment worth running. It would answer whether SPIN learns general road connectivity patterns or overfits to the training distribution.


References

  • Bandara, W. G. C., Valanarasu, J. M. J., & Patel, V. M. (2021). SPIN Road Mapper: Extracting Roads from Aerial Images via Spatial and Interaction Space Graph Reasoning for Autonomous Driving. arXiv:2109.07701.
  • Mnih, V. (2013). Machine Learning for Aerial Image Labeling. Ph.D. dissertation, University of Toronto.
  • Demir, I., et al. (2018). DeepGlobe 2018: A Challenge to Parse the Earth through Satellite Images. CVPR Workshops.
  • Kipf, T. N., & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks. ICLR.
  • Chen, Y., et al. (2019). Graph-Based Global Reasoning Networks. CVPR.
  • Li, Y., & Gupta, A. (2018). Beyond Grids: Learning Graph Representations for Visual Recognition. NeurIPS.

About

PyTorch implementation of SPIN Road Mapper: road segmentation from aerial images using spatial and interaction space graph reasoning on stacked hourglass networks.
