Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7eb4372
Added png support. Added folder in configs. Added
Jun 2, 2025
fc31d4e
finished Early Stopping
Jun 4, 2025
7d9aed6
forgot to stage one file. Early Stopping finished.
Jun 4, 2025
e29bb57
added copilot readme
Jun 4, 2025
13db21a
black. mypy.
Jun 4, 2025
62dfee5
get_transform is now without seed because transform itself is too
Jun 4, 2025
9d6d7dd
made everything ready for CI/CD
Jun 4, 2025
82def70
rerun CI/CD
Jun 4, 2025
c568b2f
Added different target metrics and save_top_k parameter to early stop…
Jun 5, 2025
44e38f9
Update src/deepforest_finetuning/training/_finetuning.py
L17L Jun 5, 2025
615cb41
Update src/deepforest_finetuning/training/_finetuning.py
L17L Jun 5, 2025
3e5ec93
Update README.md
L17L Jun 5, 2025
f04b37d
Update README.md
L17L Jun 5, 2025
ddb6cc9
Update README.md
L17L Jun 5, 2025
4f84281
Update README.md
L17L Jun 5, 2025
bb5a68c
eveyrthing fixed from PR review
Jun 5, 2025
0f115a5
unused import fix
Jun 5, 2025
0bb30ec
Merge branch 'main' into Early-Stopping-&-Folder-in-configs
josafatburmeister Aug 30, 2025
462f3f9
update Readme
josafatburmeister Aug 30, 2025
f1358be
clean up code
josafatburmeister Aug 30, 2025
9ded357
update code formatting
josafatburmeister Aug 30, 2025
6fe14aa
update code to latest deepforest changes
josafatburmeister Sep 2, 2025
08f2b17
update Readme
josafatburmeister Sep 2, 2025
f6db2fc
update Readme
josafatburmeister Sep 2, 2025
01d0d38
Merge branch 'main' into Early-Stopping-&-Folder-in-configs
josafatburmeister Sep 2, 2025
e853e34
fix pylint error
josafatburmeister Sep 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 223 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,228 @@
## Fine-Tuning DeepForest for Forest Tree Detection in High-Resolution UAV Imagery

A pipeline for fine-tuning the [DeepForest](https://deepforest.readthedocs.io/en/v1.5.0/) model for tree detection to specific target sites.
A Python package for fine-tuning the [DeepForest](https://github.com/weecology/DeepForest) model on custom data. DeepForest is a deep learning model for detecting trees in aerial RGB imagery. This package extends DeepForest by providing a workflow to fine-tune the model on your own datasets. Key features include:

- Data preprocessing for various input formats
Comment thread
L17L marked this conversation as resolved.
- Automatic label projection from 3D point clouds to 2D orthophotos
- Image rescaling and tiling
- Model fine-tuning with multiple random seeds for robust evaluation
- Prediction on new images with customizable tiling
- Evaluation metrics calculation (precision, recall, F1 score)

## Installation

### Using Conda

1. Make sure you have conda installed. If conda is installed, `conda --version` should output the conda version.

2. Clone this repository:
```bash
git clone https://github.com/yourusername/deepforest-finetuning.git
cd deepforest-finetuning
```

3. Create and activate a conda environment from the provided environment.yml file:
```bash
conda env create -f environment.yml
conda activate deepforest-env
```

4. Install the package in development mode:
```bash
pip install -e .
```

### Using pip

1. Make sure that Python3 and pip are installed.

2. Clone this repository:
```bash
git clone https://github.com/yourusername/deepforest-finetuning.git
cd deepforest-finetuning
```

3. Install the [pointtorch](https://ai4trees.github.io/pointtorch/v0.2.0/) package and its dependencies (`${TORCH}` should be replaced by the PyTorch version and `${CUDA}` by `cpu`, `cu126`, etc., depending on the PyTorch installation):
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/${CUDA}
pip install torch-scatter torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
pip install pointtorch
```

4. Install the [DeepForest](https://deepforest.readthedocs.io/en/v1.5.0/getting_started/install.html) package:

```bash
pip install "git+https://github.com/weecology/DeepForest.git"
```

5. Install the package in development mode:
```bash
pip install -e .
```

### Using Docker

A Dockerfile is provided for containerized usage:

```bash
docker build -t deepforest-finetuning .
docker run --gpus all --rm -it -v /path/to/your/data:/workspace/data/ deepforest-finetuning
```

## Workflow

### 1. Data Preprocessing

The package supports multiple preprocessing steps:

#### a. Projecting Labels from Point Clouds

If you have 3D point cloud data with pointwise tree instance labels in addition to 2D aerial images, you can project them to 2D bounding boxes:

```bash
python scripts/preprocessing.py configs/preprocessing/project_point_cloud_labels.toml
```

Required configuration:
```toml
base_dir = "/path/to/your/data"
point_cloud_paths = ["pointcloud1.las", "pointcloud2.las"]
image_paths = ["image1.tif", "image2.tif"]
label_json_output_paths = ["labels1.json", "labels2.json"]
```

#### b. Filtering Labels

Filters labels using non-maximum suppression based on overlap and size:

```bash
python scripts/preprocessing.py configs/preprocessing/filter_labels.toml
```

Required configuration:
```toml
base_dir = "/path/to/your/data"
input_label_folder = "labels"
output_label_folder = "labels_filtered"
iou_threshold = 0.5
```

#### c. Image Rescaling

Rescale images and corresponding labels to different resolutions:

```bash
python scripts/preprocessing.py configs/preprocessing/rescale_images.toml
```

Required configuration:
```toml
base_dir = "/path/to/your/data"
# input_images can either be a list of individual file paths or string specifying a folder path
input_images = ["image1.tif", "image2.tif"]
# if no labels are available, input_label_folders can be left empty
input_label_folders = ["labels"]
# there must be one output folder for each target resolution
output_folders = ["rescaled_2_5_cm", "rescaled_5_cm"]
target_resolutions = [0.025, 0.05]
```

### 2. Model Fine-tuning

Fine-tune the DeepForest model on your custom dataset:

```bash
python scripts/finetuning.py configs/finetuning/finetuning_5_cm_manual_labeling_small.toml
```

Example configuration:
```toml
base_dir = "/path/to/your/data"
tmp_dir = "./tmp"
patch_size = 640
patch_overlap = 0.2
image_folder = "images"
train_annotation_files = ["annotations"]
test_annotation_files = ["test_annotations"]
epochs = 20
seeds = [0, 1, 2, 3, 4]
learning_rate = 0.0001
checkpoint_dir = "checkpoints"
early_stopping_patience = 2
save_top_k = 1
target_metric = "test_f1"
```

This will:
1. Split images into patches and load training and test datasets
3. Fine-tune the model for the specified number of epochs
4. Run with multiple random seeds for robust evaluation
5. Save checkpoints and logs

### 3. Making Predictions

Make predictions with the fine-tuned model:

```bash
python scripts/prediction.py configs/finetuning/predict_finetuned_5_cm.toml
```

Example configuration:
```toml
checkpoint_path = "/path/to/checkpoint.pt"
image_files = ["/path/to/image.tif"]
predict_tile = true
patch_size = 1000
patch_overlap = 0.2

[prediction_export]
output_folder = "/path/to/predictions"
output_file_name = "predictions.csv"
```

### 4. Evaluating Results

Evaluate predictions against ground truth:

```bash
python scripts/evaluate.py configs/evaluation/evaluate_finetuned_5_cm.toml
```

Example configuration:
```toml
prediction_file = "/path/to/predictions.csv"
label_file = "/path/to/ground_truth.csv"
iou_threshold = 0.4
output_file = "/path/to/evaluation_results.csv"
```

## Configuration Files

All workflows are configured using TOML files. Example configurations are provided in the `configs` folder.

## Other Features

### Multiple Random Seeds

To ensure robust evaluation, you can run fine-tuning with multiple random seeds:

```toml
seeds = [0, 1, 2, 3, 4]
```

This will train separate models with different weight initializations and report the average performance.

### Early Stopping and Model Checkpointing

To prevent overfitting, you can enable early stopping and control model checkpointing:

```toml
early_stopping_patience = 2 # Stop training if performance doesn't improve for this many epochs
save_top_k = 1 # Save the top k best models based on the target metric
target_metric = "test_f1" # Metric to monitor for early stopping and checkpointing
```

The `mode` (min/max) is automatically inferred from the metric name. Metrics containing "loss" use "min" mode (lower is better), all others use "max" mode (higher is better).

### How to Cite

Expand Down
22 changes: 22 additions & 0 deletions configs/tree-ai-2025/finetuning_5_cm_TreeAI_full.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
base_dir = "/mnt/daten/TreeAI/finetuning/for_finetuning/5cm"
tmp_dir = "./tmp"
patch_size = 640
patch_overlap = 0.1
image_folder = "images"
train_annotation_files = [
"annotations",
]
test_annotation_files = [
"test_annotations",
]
epochs = 20
seeds = [0, 1, 2, 3, 4]
learning_rate = 0.0001
checkpoint_dir = "checkpoints"
early_stopping_patience = 2 # Set to the number of epochs to wait before stopping, or remove to disable early stopping
save_top_k = 1 # Save the top k best models based on target_metric
target_metric = "test_f1" # Metric to monitor for early stopping and model checkpointing.
# Mode min/max is inferred from the metric name.

[prediction_export]
output_folder = "predictions"
6 changes: 5 additions & 1 deletion scripts/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def preprocessing_step(config_path: str, config_type: Type, script_function: Cal
preprocess_manually_corrected_labels,
ManuallyCorrectedLabelPreprocessingConfig,
),
("project_point_cloud_labels", project_point_cloud_labels, PointCloudLabelProjectionConfig),
(
"project_point_cloud_labels",
project_point_cloud_labels,
PointCloudLabelProjectionConfig,
),
("rescale_images", rescale_images, ImageRescalingConfig),
]
fire_dict = {}
Expand Down
7 changes: 5 additions & 2 deletions src/deepforest_finetuning/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ class TrainingConfig: # pylint: disable=too-many-instance-attributes
patch_overlap: float
learning_rate: float
tmp_dir: str
train_annotation_files: List[str]
test_annotation_files: List[str]
train_annotation_files: Union[List[str], str]
test_annotation_files: Union[List[str], str]
prediction_export: ExportConfig
checkpoint_dir: Optional[str] = None
iou_threshold: float = 0.5
Expand All @@ -96,6 +96,9 @@ class TrainingConfig: # pylint: disable=too-many-instance-attributes
precision: str = "16-mixed"
float32_matmul_precision: str = "medium"
log_dir: str = "./logs"
early_stopping_patience: Optional[int] = None
save_top_k: int = 1
target_metric: str = "test_f1"


@dataclasses.dataclass
Expand Down
21 changes: 17 additions & 4 deletions src/deepforest_finetuning/evaluation/_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@

from pathlib import Path
import warnings
from typing import Union
from typing import Dict, Union

from deepforest.evaluate import evaluate_boxes
import pandas as pd


def evaluate(
predictions: pd.DataFrame, annotations: pd.DataFrame, iou_threshold: float, output_file: Union[str, Path]
) -> None:
predictions: pd.DataFrame,
annotations: pd.DataFrame,
iou_threshold: float,
output_file: Union[str, Path],
) -> Dict[str, float]:
"""
Evaluates a model's predictions and stores the evaluation metrics as CSV file.

Expand All @@ -22,9 +25,12 @@ def evaluate(
iou_threshold: Threshold for the IoU between predicted and target bounding boxes at which predicted bounding
boxes are counted as true positives.
output_file: Path of the CSV file in which to store the evaluation metrics.

Returns:
Dictionary containing the evaluation metrics (precision, recall, f1).
"""

# ignore deprecated warnings from pandas raised by deepforest.IoU (line 113: iou_df = pd.concat(iou_df))
# ignore deprecated pandas warnings raised by deepforest.IoU (line 113: iou_df = pd.concat(iou_df))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
results = evaluate_boxes(
Expand All @@ -48,3 +54,10 @@ def evaluate(
Path(output_file).parent.mkdir(exist_ok=True, parents=True)

df.to_csv(output_file, index=False)

# Return metrics dictionary for use with Lightning logger
return {
"precision": results["precision"],
"recall": results["recall"],
"f1": results["f1"],
}
2 changes: 0 additions & 2 deletions src/deepforest_finetuning/prediction/_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,12 @@ def prediction(
if predict_tile:
pred = model.predict_tile(
image=tree_dataset[img_idx].astype(np.float32),
return_plot=False,
patch_size=patch_size,
patch_overlap=patch_overlap,
)
else:
pred = model.predict_image(
image=tree_dataset[img_idx].astype(np.float32),
return_plot=False,
)
image_name = tree_dataset.__getname__(img_idx)
pred["image_path"] = image_name
Expand Down
16 changes: 15 additions & 1 deletion src/deepforest_finetuning/prediction/_prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,21 @@ def __getitem__(self, idx: int) -> npt.NDArray:
Returns: Image data.
"""
img_path = self.image_files[idx]
image_array = np.array(imread(img_path))[:, :, :3].astype(np.uint8)
file_ext = Path(img_path).suffix.lower()

if file_ext in [".tif", ".tiff"]:
# Use tifffile for TIFF images
image_array = np.array(imread(img_path))
else:
# Use PIL for other image formats (PNG, JPG, etc.)
image = Image.open(img_path)
image_array = np.array(image)

# Ensure we only take the first 3 channels if there are more
if image_array.ndim >= 3 and image_array.shape[2] > 3:
image_array = image_array[:, :, :3]

image_array = image_array.astype(np.uint8)

if self.resize_images_to is not None:
image = Image.fromarray(image_array)
Expand Down
7 changes: 6 additions & 1 deletion src/deepforest_finetuning/preprocessing/_filter_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def filter_bounding_boxed_with_size_based_nms(coco_json: Dict[str, Any], iou_thr
for annotation in coco_json["annotations"]:
bounding_box = annotation["bbox"]
bounding_boxes.append(
[bounding_box[0], bounding_box[1], bounding_box[0] + bounding_box[2], bounding_box[1] + bounding_box[3]]
[
bounding_box[0],
bounding_box[1],
bounding_box[0] + bounding_box[2],
bounding_box[1] + bounding_box[3],
]
)
bounding_box_sizes.append(bounding_box[2] * bounding_box[3])

Expand Down
Loading