Ariel0818/Geosphere-DETR

GeoSphere-DETR

A depth-aware extension of RT-DETR (Real-Time DEtection TRansformer) that fuses RGB images with geometry/depth information for improved citrus fruit detection. Built on top of Ultralytics. The dataset used for this project can be found at citrus-fruit-rgbd.

What's Changed

This repo modifies the Ultralytics RT-DETR pipeline in 10 files. Everything else is the original Ultralytics codebase, kept intact to preserve the training/validation infrastructure.

Model Architecture

Core model: ultralytics/nn/tasks.py

RTDETRDetectionModel is extended to optionally accept a geometry branch:

  • geo_branch: a small CNN (3× Conv-BN-SiLU) that encodes the depth/geometry input into 32-channel features
  • gate_heads: per-scale learned gates that control how much geometry information is fused at each FPN level (P3/P4/P5)
  • geo_to_feat / concat_proj: 1×1 projections for additive or concat fusion
  • Fusion mode is controlled entirely by YAML; setting geom.fusion: off disables the branch with zero overhead
  • apply_geom_dropout(): train-time geometry dropout that simulates missing/corrupted depth (random block masking + full blackout)
  • apply_geom_missing_budget(): test-time controlled corruption for robustness evaluation
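
The components listed above can be pictured with a minimal PyTorch sketch. This is not the code in ultralytics/nn/tasks.py: the class name GatedGeomFusion, the intermediate width, the strides, and the fusion formula feat * (1 + alpha * gate) are assumptions made for illustration. Only the 3× Conv-BN-SiLU encoder with 32 output channels, the per-scale gates, and the alpha parameter come from the description.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_silu(c_in: int, c_out: int, stride: int) -> nn.Sequential:
    """One Conv-BN-SiLU unit with a 3x3 kernel."""
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.SiLU(),
    )


class GatedGeomFusion(nn.Module):
    """Hypothetical sketch: encode geometry once, then gate each pyramid level."""

    def __init__(self, geom_ch, feat_chs, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        # 3x Conv-BN-SiLU ending in 32 channels (from the README).
        # The intermediate width of 16 and the strides are assumptions.
        self.geo_branch = nn.Sequential(
            conv_bn_silu(geom_ch, 16, stride=2),
            conv_bn_silu(16, 32, stride=2),
            conv_bn_silu(32, 32, stride=2),
        )
        # One 1x1 gate head per pyramid level (P3/P4/P5).
        self.gate_heads = nn.ModuleList(
            [nn.Conv2d(32, c, kernel_size=1) for c in feat_chs]
        )

    def forward(self, feats, geom=None):
        if geom is None:
            # No geometry supplied: features pass through unchanged.
            return list(feats)
        g = self.geo_branch(geom)
        fused = []
        for feat, head in zip(feats, self.gate_heads):
            # Resize the geometry features to this level's spatial size.
            g_lvl = F.interpolate(
                g, size=feat.shape[-2:], mode="bilinear", align_corners=False
            )
            gate = torch.sigmoid(head(g_lvl))  # in (0, 1) per channel and pixel
            # Assumed form of "gated_scale": rescale features by the gate.
            fused.append(feat * (1.0 + self.alpha * gate))
        return fused
```

Under this assumed form, a gate near zero leaves the RGB features almost untouched, so a level that gains nothing from depth can learn to ignore it.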

Sphere head: ultralytics/nn/modules/head.py

Optional Sphere-Consistency head attached to RTDETRDecoder:

  • A small MLP per query that predicts a binary foreground score (sphere logit)
  • Enabled/disabled via sphere.enable in the model YAML
  • At inference, can optionally re-score detections: score += lambda * sigmoid(sphere_logit)
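
The re-scoring rule above is simple enough to write out in plain Python. The function names sigmoid and rescore are illustrative, not taken from the repo; the arithmetic is exactly score + lambda * sigmoid(sphere_logit).

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rescore(scores, sphere_logits, lam):
    """Return score + lam * sigmoid(sphere_logit) for each query."""
    return [s + lam * sigmoid(z) for s, z in zip(scores, sphere_logits)]
```

Because the sigmoid term is always positive, a re-scored value can exceed 1 when lambda > 0. Whether the repo clamps the result is not stated here, so treat the output as a ranking score unless you clamp it yourself.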

Loss: ultralytics/models/utils/loss.py

RTDETRDetectionLoss gains an optional sphere consistency loss (binary cross-entropy against Hungarian-matched foreground targets) when the sphere head is active.
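
As a rough illustration of that loss, the sketch below computes binary cross-entropy on the sphere logits, with queries matched by the Hungarian assignment labelled 1 and all others labelled 0. The function names, the mean reduction over queries, and the way gain multiplies the result are assumptions; the actual normalisation in loss.py may differ.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Stable binary cross-entropy computed directly on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def sphere_consistency_loss(sphere_logits, matched_query_idx, gain=1.0):
    """Mean BCE over queries: matched queries are foreground, the rest background."""
    matched = set(matched_query_idx)
    losses = [
        bce_with_logits(z, 1.0 if i in matched else 0.0)
        for i, z in enumerate(sphere_logits)
    ]
    return gain * sum(losses) / len(losses)
```

With sphere.gain set to 0.0, as in the default config, this term contributes nothing to the total loss under the assumed weighting.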

Data pipeline

| File | Change |
| --- | --- |
| data/augment.py | Format: splits 4-channel input (RGBD) into img (3ch) + geom (1ch) so all augmentations stay aligned |
| data/dataset.py | collate_fn: handles depth (Bx1xHxW) and geom (BxCgxHxW) tensors in addition to standard keys |
| data/base.py | Loads RGBD images as 4-channel when channels: 4 is set in the data YAML |
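
The split-then-collate idea can be sketched in a few lines of PyTorch. Because the split happens only at the formatting stage, any flip or crop applied earlier to the 4-channel array affects RGB and depth identically. The helper names split_rgbd and collate are hypothetical, and the sketch assumes a single geometry channel; other geom_mode settings may yield more channels.

```python
import torch


def split_rgbd(x: torch.Tensor):
    """Split a (4, H, W) RGBD tensor into img (3, H, W) and geom (1, H, W)."""
    if x.ndim != 3 or x.shape[0] != 4:
        raise ValueError(f"expected a (4, H, W) tensor, got {tuple(x.shape)}")
    # Slicing with 3: (not 3) keeps the channel axis on the geometry map.
    return x[:3], x[3:]


def collate(samples):
    """Stack per-sample dicts into img (B, 3, H, W) and geom (B, 1, H, W)."""
    return {
        "img": torch.stack([s["img"] for s in samples]),
        "geom": torch.stack([s["geom"] for s in samples]),
    }
```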

Training / Validation

| File | Change |
| --- | --- |
| engine/trainer.py | Passes geom=batch["geom"] into the model forward during training |
| engine/validator.py | Concatenates depth channel to image and passes geom during validation |
| models/rtdetr/val.py | RTDETRValidator passes geom through predict() |

Configuration

All new features are controlled via the model YAML. A plain rtdetr-l.yaml without a geom: section runs as the original RT-DETR with no overhead.

# geom fusion block (omit entirely to use standard RT-DETR)
geom:
  alpha: 0.5
  fusion: gated_scale   # off | gated_scale | gated_add | concat
  share_gate_heads: false
  debug: true
  debug_every: 10

  # train-time geometry dropout (simulates sensor failure)
  dropout:
    enable: true
    p: 0.3        # probability of applying dropout
    p_full: 0.3   # probability of full blackout (given dropout applied)
    blocks: [1, 4]
    scale: [0.05, 0.35]

  # test-time controlled corruption (for ablation)
  missing:
    ratio: 0.0    # 0.0 = off; try 0.1 / 0.3 / 0.5 / 0.7
    seed: 0
    scale: [0.05, 0.35]

# sphere consistency head (optional)
sphere:
  enable: false
  gain: 0.0       # loss weight during training
  lambda: 0.0     # inference re-scoring strength
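
To make the dropout parameters concrete, here is a plain-Python sketch of one way they could be interpreted. It is not the repo's apply_geom_dropout(): the function name geom_dropout_mask is hypothetical, and reading scale as bounds on each block's side length (as a fraction of the map side) is an assumption. The roles of p, p_full, and blocks follow the comments in the YAML.

```python
import random


def geom_dropout_mask(h, w, p=0.3, p_full=0.3, blocks=(1, 4),
                      scale=(0.05, 0.35), rng=None):
    """Return an h x w keep-mask (1 = keep, 0 = drop) for one geometry map."""
    rng = rng or random.Random()
    mask = [[1] * w for _ in range(h)]
    if rng.random() >= p:
        # Dropout not applied: geometry stays intact.
        return mask
    if rng.random() < p_full:
        # Full blackout: simulates a dead sensor.
        return [[0] * w for _ in range(h)]
    for _ in range(rng.randint(blocks[0], blocks[1])):
        # Assumption: "scale" bounds each block side as a fraction of the map side.
        bh = max(1, int(rng.uniform(*scale) * h))
        bw = max(1, int(rng.uniform(*scale) * w))
        top = rng.randint(0, h - bh)
        left = rng.randint(0, w - bw)
        for r in range(top, top + bh):
            mask[r][left:left + bw] = [0] * bw
    return mask
```

With p: 0.3 and p_full: 0.3, about 9% of training samples (0.3 × 0.3) would see a full blackout and about 21% (0.3 × 0.7) would see block masking; the remaining 70% keep their geometry untouched.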

Data YAML

train: path/to/train/images
val:   path/to/val/images

channels: 4          # 4 = RGBD input
geom_mode: BN        # D (depth only) | B | N | BN (bump + normal)
nc: 2
names: [class_a, class_b]

Training

yolo task=detect mode=train \
  model=rtdetr-l-geom.yaml \
  pretrained=rtdetr-l.pt \
  data=your_data.yaml \
  imgsz=640 epochs=100 batch=8 device=0

To run the original RT-DETR without any geometry branch, use rtdetr-l.yaml as normal; nothing in this repo changes its behaviour.
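
One way such a fallback could be resolved from a parsed model config is sketched below. The function name resolve_fusion and the choice to default to off when the fusion key is absent are assumptions, not the repo's code. One detail worth handling is that YAML 1.1 parsers such as PyYAML read an unquoted off as the boolean False, so the sketch maps False back to "off".

```python
ALLOWED_FUSION = {"off", "gated_scale", "gated_add", "concat"}


def resolve_fusion(model_cfg: dict) -> str:
    """Return the fusion mode, treating a missing geom section as 'off'."""
    geom = model_cfg.get("geom") or {}
    raw = geom.get("fusion", "off")
    if raw is False:
        # An unquoted `off` in YAML 1.1 is parsed as boolean False.
        return "off"
    mode = str(raw).lower()
    if mode not in ALLOWED_FUSION:
        raise ValueError(
            f"unknown geom.fusion {raw!r}; expected one of {sorted(ALLOWED_FUSION)}"
        )
    return mode
```

A caller could then skip building the geometry branch entirely whenever the result is "off", which is what keeps a plain rtdetr-l.yaml free of extra modules.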


Requirements

Same as Ultralytics:

torch>=2.0
torchvision

Acknowledgements
