A depth-aware extension of RT-DETR (Real-Time DEtection TRansformer) that fuses RGB images with geometry/depth information for improved citrus fruit detection. Built on top of Ultralytics. The dataset used for this project is available at citrus-fruit-rgbd.
This repo modifies the Ultralytics RT-DETR pipeline in 10 files. Everything else is the original Ultralytics codebase, kept intact to preserve the training/validation infrastructure.
RTDETRDetectionModel is extended to optionally accept a geometry branch:
- geo_branch: a small CNN (3× Conv-BN-SiLU) that encodes the depth/geometry input into 32-channel features
- gate_heads: per-scale learned gates that control how much geometry information is fused at each FPN level (P3/P4/P5)
- geo_to_feat / concat_proj: 1×1 projections for additive or concat fusion
- Fusion mode is controlled entirely by YAML; setting geom.fusion: off disables the branch with zero overhead
- apply_geom_dropout(): train-time geometry dropout that simulates missing/corrupted depth (random block masking + full blackout)
- apply_geom_missing_budget(): test-time controlled corruption for robustness evaluation
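The two-stage sampling behind apply_geom_dropout() (first decide whether to corrupt at all, then choose between a full blackout and block masking) can be illustrated with a framework-free sketch. The function name geom_dropout, the list-of-rows map representation, and the reading of blocks as a count range and scale as a block side length relative to the map are all assumptions made for illustration; the repo's implementation presumably works on batched tensors and may differ in detail.

```python
import random


def geom_dropout(geom, p=0.3, p_full=0.3, blocks=(1, 4), scale=(0.05, 0.35), rng=None):
    """Return a copy of an H x W geometry map with random regions zeroed.

    Hypothetical helper: parameter semantics are assumed, not taken from the repo.
    """
    rng = rng or random.Random()
    h, w = len(geom), len(geom[0])
    out = [row[:] for row in geom]  # never modify the caller's map

    # Stage 1: with probability 1 - p, leave the sample untouched.
    if rng.random() >= p:
        return out

    # Stage 2a: full blackout, as if the sensor returned nothing.
    if rng.random() < p_full:
        return [[0.0] * w for _ in range(h)]

    # Stage 2b: zero a few rectangular blocks at random positions.
    for _ in range(rng.randint(blocks[0], blocks[1])):
        bh = min(h, max(1, round(h * rng.uniform(scale[0], scale[1]))))
        bw = min(w, max(1, round(w * rng.uniform(scale[0], scale[1]))))
        top = rng.randint(0, h - bh)
        left = rng.randint(0, w - bw)
        for y in range(top, top + bh):
            for x in range(left, left + bw):
                out[y][x] = 0.0
    return out
```

Passing a seeded random.Random as rng makes the corruption reproducible, which is the property a test-time variant with a fixed seed would need.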
Optional Sphere-Consistency head attached to RTDETRDecoder:
- A small MLP per query that predicts a binary foreground score (sphere logit)
- Enabled/disabled via sphere.enable in the model YAML
- At inference, can optionally re-score detections: score += lambda * sigmoid(sphere_logit)
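The re-scoring rule can be written out in a few lines of plain Python. This is a minimal sketch of the formula only: the names rescore and lam are hypothetical, and where the repo applies the rule in its post-processing is not shown here.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rescore(scores, sphere_logits, lam=0.0):
    """Apply score += lam * sigmoid(sphere_logit) to each detection."""
    return [s + lam * sigmoid(z) for s, z in zip(scores, sphere_logits)]


# Two detections with equal class scores; the first has a higher sphere logit,
# so it ends up ranked above the second after re-scoring.
boosted = rescore([0.6, 0.6], [2.0, -2.0], lam=0.2)
```

With lam = 0.0 the scores are returned unchanged. With lam > 0 a re-scored value can exceed 1, so it is better read as a ranking score than as a probability.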
RTDETRDetectionLoss gains an optional sphere consistency loss (binary cross-entropy against Hungarian-matched foreground targets) when the sphere head is active.
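As a rough illustration of this loss term, the sketch below labels every Hungarian-matched query as foreground (1) and every other query as background (0), then averages a binary cross-entropy over all queries. The function names, the mean reduction, and the way gain scales the result are assumptions; the repo's loss may normalise or weight the term differently.

```python
import math


def bce_with_logits(logit, target):
    """Stable binary cross-entropy for one logit and a 0/1 target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def sphere_consistency_loss(sphere_logits, matched_query_ids, gain=1.0):
    """Mean BCE over all queries, scaled by gain (hypothetical reduction)."""
    if not sphere_logits:
        return 0.0
    matched = set(matched_query_ids)
    losses = [
        bce_with_logits(z, 1.0 if i in matched else 0.0)
        for i, z in enumerate(sphere_logits)
    ]
    return gain * sum(losses) / len(losses)
```

A gain of 0.0, the default in the YAML shown in this README, switches the term off without removing the head.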
| File | Change |
|---|---|
| data/augment.py | Format: splits 4-channel input (RGBD) into img (3ch) + geom (1ch) so all augmentations stay aligned |
| data/dataset.py | collate_fn: handles depth (Bx1xHxW) and geom (BxCgxHxW) tensors in addition to standard keys |
| data/base.py | Loads RGBD images as 4-channel when channels: 4 is set in the data YAML |
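The reason for augmenting the 4-channel image first and splitting it afterwards is that any spatial transform then moves colour and depth together. The toy sketch below shows this with a horizontal flip on a tiny image stored as rows of (r, g, b, d) tuples; hflip and split_rgbd are hypothetical helpers, not functions from the repo, which works on arrays and tensors instead.

```python
def hflip(image):
    """Horizontally flip an image stored as rows of per-pixel tuples."""
    return [row[::-1] for row in image]


def split_rgbd(image):
    """Split rows of (r, g, b, d) pixels into an RGB image and a 1-channel map."""
    img = [[px[:3] for px in row] for row in image]
    geom = [[px[3:] for px in row] for row in image]
    return img, geom


rgbd = [
    [(1, 2, 3, 9), (4, 5, 6, 8)],
    [(7, 8, 9, 7), (0, 0, 0, 6)],
]

# Augment once on the joint image, then split: both parts see the same flip.
flipped_img, flipped_geom = split_rgbd(hflip(rgbd))
```

Splitting before augmentation would require applying every random transform twice with identical parameters, which is easy to get wrong.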
| File | Change |
|---|---|
| engine/trainer.py | Passes geom=batch["geom"] into the model forward during training |
| engine/validator.py | Concatenates depth channel to image and passes geom during validation |
| models/rtdetr/val.py | RTDETRValidator passes geom through predict() |
All new features are controlled via the model YAML. A plain rtdetr-l.yaml without a geom: section runs as the original RT-DETR with no overhead.
# geom fusion block (omit entirely to use standard RT-DETR)
geom:
  alpha: 0.5
  fusion: gated_scale # off | gated_scale | gated_add | concat
  share_gate_heads: false
  debug: true
  debug_every: 10
  # train-time geometry dropout (simulates sensor failure)
  dropout:
    enable: true
    p: 0.3 # probability of applying dropout
    p_full: 0.3 # probability of full blackout (given dropout applied)
    blocks: [1, 4]
    scale: [0.05, 0.35]
  # test-time controlled corruption (for ablation)
  missing:
    ratio: 0.0 # 0.0 = off; try 0.1 / 0.3 / 0.5 / 0.7
    seed: 0
    scale: [0.05, 0.35]
# sphere consistency head (optional)
sphere:
  enable: false
  gain: 0.0 # loss weight during training
  lambda: 0.0 # inference re-scoring strength

train: path/to/train/images
val: path/to/val/images
channels: 4 # 4 = RGBD input
geom_mode: BN # D (depth only) | B | N | BN (bump + normal)
nc: 2
names: [class_a, class_b]

yolo task=detect mode=train \
  model=rtdetr-l-geom.yaml \
  pretrained=rtdetr-l.pt \
  data=your_data.yaml \
  imgsz=640 epochs=100 batch=8 device=0

To run the original RT-DETR without any geometry branch, use rtdetr-l.yaml as normal; nothing in this repo changes its behaviour.
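The on/off behaviour described above comes down to one lookup on the parsed model YAML. The sketch below shows one way to resolve the fusion mode from an already-loaded config dict; fusion_mode and geometry_enabled are hypothetical helpers, not the repo's API. One detail worth noting: YAML 1.1 loaders such as PyYAML read a bare off as the boolean False, so the sketch treats False the same as the string "off".

```python
FUSION_MODES = {"off", "gated_scale", "gated_add", "concat"}


def fusion_mode(model_cfg):
    """Resolve the fusion mode from a parsed model YAML (a plain dict)."""
    geom = model_cfg.get("geom") or {}  # a missing or empty section means no branch
    mode = geom.get("fusion", "off")
    if mode is False:  # a bare `off` parsed as a YAML 1.1 boolean
        mode = "off"
    if mode not in FUSION_MODES:
        raise ValueError(f"unknown geom.fusion value: {mode!r}")
    return mode


def geometry_enabled(model_cfg):
    """True only when a geom section requests an active fusion mode."""
    return fusion_mode(model_cfg) != "off"
```

For example, geometry_enabled({}) is False for a config with no geom section, and geometry_enabled({"geom": {"fusion": "gated_scale"}}) is True.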
Same as Ultralytics:
torch>=2.0
torchvision
- RT-DETR: Zhao et al., 2023
- Ultralytics: base framework (AGPL-3.0)
