diff --git a/README.md b/README.md index 991a78f..77be1d4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,228 @@ ## Fine-Tuning DeepForest for Forest Tree Detection in High-Resolution UAV Imagery -A pipeline for fine-tuning the [DeepForest](https://deepforest.readthedocs.io/en/v1.5.0/) model for tree detection to specific target sites. +A Python package for fine-tuning the [DeepForest](https://github.com/weecology/DeepForest) model on custom data. DeepForest is a deep learning model for detecting trees in aerial RGB imagery. This package extends DeepForest by providing a workflow to fine-tune the model on your own datasets. Key features include: + +- Data preprocessing for various input formats +- Automatic label projection from 3D point clouds to 2D orthophotos +- Image rescaling and tiling +- Model fine-tuning with multiple random seeds for robust evaluation +- Prediction on new images with customizable tiling +- Evaluation metrics calculation (precision, recall, F1 score) + +## Installation + +### Using Conda + +1. Make sure you have conda installed. If conda is installed, `conda --version` should output the conda version. + +2. Clone this repository: + ```bash + git clone https://github.com/yourusername/deepforest-finetuning.git + cd deepforest-finetuning + ``` + +3. Create and activate a conda environment from the provided environment.yml file: + ```bash + conda env create -f environment.yml + conda activate deepforest-env + ``` + +4. Install the package in development mode: + ```bash + pip install -e . + ``` + +### Using pip + +1. Make sure that Python3 and pip are installed. + +2. Clone this repository: + ```bash + git clone https://github.com/yourusername/deepforest-finetuning.git + cd deepforest-finetuning + ``` + +3. Install the [pointtorch](https://ai4trees.github.io/pointtorch/v0.2.0/) package and its dependencies (`${TORCH}` should be replaced by the PyTorch version and `${CUDA}` by `cpu`, `cu126`, etc., depending on the PyTorch installation): + ```bash + pip install torch torchvision --index-url https://download.pytorch.org/whl/${CUDA} + pip install torch-scatter torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html + pip install pointtorch + ``` + +4. Install the [DeepForest](https://deepforest.readthedocs.io/en/v1.5.0/getting_started/install.html) package: + + ```bash + pip install "git+https://github.com/weecology/DeepForest.git" + ``` + +5. Install the package in development mode: + ```bash + pip install -e . + ``` + +### Using Docker + +A Dockerfile is provided for containerized usage: + +```bash +docker build -t deepforest-finetuning . +docker run --gpus all --rm -it -v /path/to/your/data:/workspace/data/ deepforest-finetuning +``` + +## Workflow + +### 1. Data Preprocessing + +The package supports multiple preprocessing steps: + +#### a. Projecting Labels from Point Clouds + +If you have 3D point cloud data with pointwise tree instance labels in addition to 2D aerial images, you can project them to 2D bounding boxes: + +```bash +python scripts/preprocessing.py configs/preprocessing/project_point_cloud_labels.toml +``` + +Required configuration: +```toml +base_dir = "/path/to/your/data" +point_cloud_paths = ["pointcloud1.las", "pointcloud2.las"] +image_paths = ["image1.tif", "image2.tif"] +label_json_output_paths = ["labels1.json", "labels2.json"] +``` + +#### b. Filtering Labels + +Filters labels using non-maximum suppression based on overlap and size: + +```bash +python scripts/preprocessing.py configs/preprocessing/filter_labels.toml +``` + +Required configuration: +```toml +base_dir = "/path/to/your/data" +input_label_folder = "labels" +output_label_folder = "labels_filtered" +iou_threshold = 0.5 +``` + +#### c. Image Rescaling + +Rescale images and corresponding labels to different resolutions: + +```bash +python scripts/preprocessing.py configs/preprocessing/rescale_images.toml +``` + +Required configuration: +```toml +base_dir = "/path/to/your/data" +# input_images can either be a list of individual file paths or string specifying a folder path +input_images = ["image1.tif", "image2.tif"] +# if no labels are available, input_label_folders can be left empty +input_label_folders = ["labels"] +# there must be one output folder for each target resolution +output_folders = ["rescaled_2_5_cm", "rescaled_5_cm"] +target_resolutions = [0.025, 0.05] +``` + +### 2. Model Fine-tuning + +Fine-tune the DeepForest model on your custom dataset: + +```bash +python scripts/finetuning.py configs/finetuning/finetuning_5_cm_manual_labeling_small.toml +``` + +Example configuration: +```toml +base_dir = "/path/to/your/data" +tmp_dir = "./tmp" +patch_size = 640 +patch_overlap = 0.2 +image_folder = "images" +train_annotation_files = ["annotations"] +test_annotation_files = ["test_annotations"] +epochs = 20 +seeds = [0, 1, 2, 3, 4] +learning_rate = 0.0001 +checkpoint_dir = "checkpoints" +early_stopping_patience = 2 +save_top_k = 1 +target_metric = "test_f1" +``` + +This will: +1. Split images into patches and load training and test datasets +3. Fine-tune the model for the specified number of epochs +4. Run with multiple random seeds for robust evaluation +5. Save checkpoints and logs + +### 3. Making Predictions + +Make predictions with the fine-tuned model: + +```bash +python scripts/prediction.py configs/finetuning/predict_finetuned_5_cm.toml +``` + +Example configuration: +```toml +checkpoint_path = "/path/to/checkpoint.pt" +image_files = ["/path/to/image.tif"] +predict_tile = true +patch_size = 1000 +patch_overlap = 0.2 + +[prediction_export] +output_folder = "/path/to/predictions" +output_file_name = "predictions.csv" +``` + +### 4. Evaluating Results + +Evaluate predictions against ground truth: + +```bash +python scripts/evaluate.py configs/evaluation/evaluate_finetuned_5_cm.toml +``` + +Example configuration: +```toml +prediction_file = "/path/to/predictions.csv" +label_file = "/path/to/ground_truth.csv" +iou_threshold = 0.4 +output_file = "/path/to/evaluation_results.csv" +``` + +## Configuration Files + +All workflows are configured using TOML files. Example configurations are provided in the `configs` folder. + +## Other Features + +### Multiple Random Seeds + +To ensure robust evaluation, you can run fine-tuning with multiple random seeds: + +```toml +seeds = [0, 1, 2, 3, 4] +``` + +This will train separate models with different weight initializations and report the average performance. + +### Early Stopping and Model Checkpointing + +To prevent overfitting, you can enable early stopping and control model checkpointing: + +```toml +early_stopping_patience = 2 # Stop training if performance doesn't improve for this many epochs +save_top_k = 1 # Save the top k best models based on the target metric +target_metric = "test_f1" # Metric to monitor for early stopping and checkpointing +``` + +The `mode` (min/max) is automatically inferred from the metric name. Metrics containing "loss" use "min" mode (lower is better), all others use "max" mode (higher is better). ### How to Cite diff --git a/configs/baseline/evaluate_without_finetuning_10_cm.toml b/configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_10_cm.toml similarity index 100% rename from configs/baseline/evaluate_without_finetuning_10_cm.toml rename to configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_10_cm.toml diff --git a/configs/baseline/evaluate_without_finetuning_2_5_cm.toml b/configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_2_5_cm.toml similarity index 100% rename from configs/baseline/evaluate_without_finetuning_2_5_cm.toml rename to configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_2_5_cm.toml diff --git a/configs/baseline/evaluate_without_finetuning_5_cm.toml b/configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_5_cm.toml similarity index 100% rename from configs/baseline/evaluate_without_finetuning_5_cm.toml rename to configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_5_cm.toml diff --git a/configs/baseline/evaluate_without_finetuning_7_5_cm.toml b/configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_7_5_cm.toml similarity index 100% rename from configs/baseline/evaluate_without_finetuning_7_5_cm.toml rename to configs/3d-geoinfo-2025/baseline/evaluate_without_finetuning_7_5_cm.toml diff --git a/configs/baseline/predict_without_finetuning_10_cm.toml b/configs/3d-geoinfo-2025/baseline/predict_without_finetuning_10_cm.toml similarity index 100% rename from configs/baseline/predict_without_finetuning_10_cm.toml rename to configs/3d-geoinfo-2025/baseline/predict_without_finetuning_10_cm.toml diff --git a/configs/baseline/predict_without_finetuning_2_5_cm.toml b/configs/3d-geoinfo-2025/baseline/predict_without_finetuning_2_5_cm.toml similarity index 100% rename from configs/baseline/predict_without_finetuning_2_5_cm.toml rename to configs/3d-geoinfo-2025/baseline/predict_without_finetuning_2_5_cm.toml diff --git a/configs/baseline/predict_without_finetuning_5_cm.toml b/configs/3d-geoinfo-2025/baseline/predict_without_finetuning_5_cm.toml similarity index 100% rename from configs/baseline/predict_without_finetuning_5_cm.toml rename to configs/3d-geoinfo-2025/baseline/predict_without_finetuning_5_cm.toml diff --git a/configs/baseline/predict_without_finetuning_7_5_cm.toml b/configs/3d-geoinfo-2025/baseline/predict_without_finetuning_7_5_cm.toml similarity index 100% rename from configs/baseline/predict_without_finetuning_7_5_cm.toml rename to configs/3d-geoinfo-2025/baseline/predict_without_finetuning_7_5_cm.toml diff --git a/configs/finetuning/10_cm/finetuning_10_cm_automatic_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_automatic_labeling_ext.toml similarity index 100% rename from configs/finetuning/10_cm/finetuning_10_cm_automatic_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_automatic_labeling_ext.toml diff --git a/configs/finetuning/10_cm/finetuning_10_cm_automatic_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_automatic_labeling_small.toml similarity index 100% rename from configs/finetuning/10_cm/finetuning_10_cm_automatic_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_automatic_labeling_small.toml diff --git a/configs/finetuning/10_cm/finetuning_10_cm_manual_correction_ext.toml b/configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_correction_ext.toml similarity index 100% rename from configs/finetuning/10_cm/finetuning_10_cm_manual_correction_ext.toml rename to configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_correction_ext.toml diff --git a/configs/finetuning/10_cm/finetuning_10_cm_manual_correction_small.toml b/configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_correction_small.toml similarity index 100% rename from configs/finetuning/10_cm/finetuning_10_cm_manual_correction_small.toml rename to configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_correction_small.toml diff --git a/configs/finetuning/10_cm/finetuning_10_cm_manual_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_labeling_ext.toml similarity index 100% rename from configs/finetuning/10_cm/finetuning_10_cm_manual_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_labeling_ext.toml diff --git a/configs/finetuning/10_cm/finetuning_10_cm_manual_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_labeling_small.toml similarity index 100% rename from configs/finetuning/10_cm/finetuning_10_cm_manual_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/10_cm/finetuning_10_cm_manual_labeling_small.toml diff --git a/configs/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_ext.toml similarity index 100% rename from configs/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_ext.toml diff --git a/configs/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_small.toml similarity index 100% rename from configs/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_automatic_labeling_small.toml diff --git a/configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_ext.toml b/configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_ext.toml similarity index 100% rename from configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_ext.toml rename to configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_ext.toml diff --git a/configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_small.toml b/configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_small.toml similarity index 100% rename from configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_small.toml rename to configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_correction_small.toml diff --git a/configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_ext.toml similarity index 100% rename from configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_ext.toml diff --git a/configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_small.toml similarity index 100% rename from configs/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/2_5_cm/finetuning_2_5_cm_manual_labeling_small.toml diff --git a/configs/finetuning/5_cm/finetuning_5_cm_automatic_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_automatic_labeling_ext.toml similarity index 100% rename from configs/finetuning/5_cm/finetuning_5_cm_automatic_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_automatic_labeling_ext.toml diff --git a/configs/finetuning/5_cm/finetuning_5_cm_automatic_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_automatic_labeling_small.toml similarity index 100% rename from configs/finetuning/5_cm/finetuning_5_cm_automatic_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_automatic_labeling_small.toml diff --git a/configs/finetuning/5_cm/finetuning_5_cm_manual_correction_ext.toml b/configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_correction_ext.toml similarity index 100% rename from configs/finetuning/5_cm/finetuning_5_cm_manual_correction_ext.toml rename to configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_correction_ext.toml diff --git a/configs/finetuning/5_cm/finetuning_5_cm_manual_correction_small.toml b/configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_correction_small.toml similarity index 100% rename from configs/finetuning/5_cm/finetuning_5_cm_manual_correction_small.toml rename to configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_correction_small.toml diff --git a/configs/finetuning/5_cm/finetuning_5_cm_manual_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_labeling_ext.toml similarity index 100% rename from configs/finetuning/5_cm/finetuning_5_cm_manual_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_labeling_ext.toml diff --git a/configs/finetuning/5_cm/finetuning_5_cm_manual_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_labeling_small.toml similarity index 100% rename from configs/finetuning/5_cm/finetuning_5_cm_manual_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/5_cm/finetuning_5_cm_manual_labeling_small.toml diff --git a/configs/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_ext.toml similarity index 100% rename from configs/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_ext.toml diff --git a/configs/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_small.toml similarity index 100% rename from configs/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_automatic_labeling_small.toml diff --git a/configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_ext.toml b/configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_ext.toml similarity index 100% rename from configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_ext.toml rename to configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_ext.toml diff --git a/configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_small.toml b/configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_small.toml similarity index 100% rename from configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_small.toml rename to configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_correction_small.toml diff --git a/configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_ext.toml b/configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_ext.toml similarity index 100% rename from configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_ext.toml rename to configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_ext.toml diff --git a/configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_small.toml b/configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_small.toml similarity index 100% rename from configs/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_small.toml rename to configs/3d-geoinfo-2025/finetuning/7_5_cm/finetuning_7_5_cm_manual_labeling_small.toml diff --git a/configs/finetuning/run_finetuning.bat b/configs/3d-geoinfo-2025/finetuning/run_finetuning.bat similarity index 100% rename from configs/finetuning/run_finetuning.bat rename to configs/3d-geoinfo-2025/finetuning/run_finetuning.bat diff --git a/configs/preprocessing/filter_labels.toml b/configs/3d-geoinfo-2025/preprocessing/filter_labels.toml similarity index 100% rename from configs/preprocessing/filter_labels.toml rename to configs/3d-geoinfo-2025/preprocessing/filter_labels.toml diff --git a/configs/preprocessing/preprocess_manually_corrected_labels.toml b/configs/3d-geoinfo-2025/preprocessing/preprocess_manually_corrected_labels.toml similarity index 100% rename from configs/preprocessing/preprocess_manually_corrected_labels.toml rename to configs/3d-geoinfo-2025/preprocessing/preprocess_manually_corrected_labels.toml diff --git a/configs/preprocessing/project_point_cloud_labels.toml b/configs/3d-geoinfo-2025/preprocessing/project_point_cloud_labels.toml similarity index 100% rename from configs/preprocessing/project_point_cloud_labels.toml rename to configs/3d-geoinfo-2025/preprocessing/project_point_cloud_labels.toml diff --git a/configs/preprocessing/rescale_images.toml b/configs/3d-geoinfo-2025/preprocessing/rescale_images.toml similarity index 100% rename from configs/preprocessing/rescale_images.toml rename to configs/3d-geoinfo-2025/preprocessing/rescale_images.toml diff --git a/configs/tree-ai-2025/finetuning_5_cm_TreeAI_full.toml b/configs/tree-ai-2025/finetuning_5_cm_TreeAI_full.toml new file mode 100644 index 0000000..873a55f --- /dev/null +++ b/configs/tree-ai-2025/finetuning_5_cm_TreeAI_full.toml @@ -0,0 +1,22 @@ +base_dir = "/mnt/daten/TreeAI/finetuning/for_finetuning/5cm" +tmp_dir = "./tmp" +patch_size = 640 +patch_overlap = 0.1 +image_folder = "images" +train_annotation_files = [ + "annotations", +] +test_annotation_files = [ + "test_annotations", +] +epochs = 20 +seeds = [0, 1, 2, 3, 4] +learning_rate = 0.0001 +checkpoint_dir = "checkpoints" +early_stopping_patience = 2 # Set to the number of epochs to wait before stopping, or remove to disable early stopping +save_top_k = 1 # Save the top k best models based on target_metric +target_metric = "test_f1" # Metric to monitor for early stopping and model checkpointing. +# Mode min/max is inferred from the metric name. + +[prediction_export] +output_folder = "predictions" diff --git a/scripts/preprocessing.py b/scripts/preprocessing.py index b4fe769..12575a4 100644 --- a/scripts/preprocessing.py +++ b/scripts/preprocessing.py @@ -43,7 +43,11 @@ def preprocessing_step(config_path: str, config_type: Type, script_function: Cal preprocess_manually_corrected_labels, ManuallyCorrectedLabelPreprocessingConfig, ), - ("project_point_cloud_labels", project_point_cloud_labels, PointCloudLabelProjectionConfig), + ( + "project_point_cloud_labels", + project_point_cloud_labels, + PointCloudLabelProjectionConfig, + ), ("rescale_images", rescale_images, ImageRescalingConfig), ] fire_dict = {} diff --git a/src/deepforest_finetuning/config/_config.py b/src/deepforest_finetuning/config/_config.py index 4061719..599c8ff 100644 --- a/src/deepforest_finetuning/config/_config.py +++ b/src/deepforest_finetuning/config/_config.py @@ -84,8 +84,8 @@ class TrainingConfig: # pylint: disable=too-many-instance-attributes patch_overlap: float learning_rate: float tmp_dir: str - train_annotation_files: List[str] - test_annotation_files: List[str] + train_annotation_files: Union[List[str], str] + test_annotation_files: Union[List[str], str] prediction_export: ExportConfig checkpoint_dir: Optional[str] = None iou_threshold: float = 0.5 @@ -96,6 +96,9 @@ class TrainingConfig: # pylint: disable=too-many-instance-attributes precision: str = "16-mixed" float32_matmul_precision: str = "medium" log_dir: str = "./logs" + early_stopping_patience: Optional[int] = None + save_top_k: int = 1 + target_metric: str = "test_f1" @dataclasses.dataclass diff --git a/src/deepforest_finetuning/evaluation/_evaluate.py b/src/deepforest_finetuning/evaluation/_evaluate.py index fc62c35..0b156a9 100644 --- a/src/deepforest_finetuning/evaluation/_evaluate.py +++ b/src/deepforest_finetuning/evaluation/_evaluate.py @@ -4,15 +4,18 @@ from pathlib import Path import warnings -from typing import Union +from typing import Dict, Union from deepforest.evaluate import evaluate_boxes import pandas as pd def evaluate( - predictions: pd.DataFrame, annotations: pd.DataFrame, iou_threshold: float, output_file: Union[str, Path] -) -> None: + predictions: pd.DataFrame, + annotations: pd.DataFrame, + iou_threshold: float, + output_file: Union[str, Path], +) -> Dict[str, float]: """ Evaluates a model's predictions and stores the evaluation metrics as CSV file. @@ -22,9 +25,12 @@ def evaluate( iou_threshold: Threshold for the IoU between predicted and target bounding boxes at which predicted bounding boxes are counted as true positives. output_file: Path of the CSV file in which to store the evaluation metrics. + + Returns: + Dictionary containing the evaluation metrics (precision, recall, f1). """ - # ignore deprecated warnings from pandas raised by deepforest.IoU (line 113: iou_df = pd.concat(iou_df)) + # ignore deprecated pandas warnings raised by deepforest.IoU (line 113: iou_df = pd.concat(iou_df)) with warnings.catch_warnings(): warnings.simplefilter("ignore") results = evaluate_boxes( @@ -48,3 +54,10 @@ def evaluate( Path(output_file).parent.mkdir(exist_ok=True, parents=True) df.to_csv(output_file, index=False) + + # Return metrics dictionary for use with Lightning logger + return { + "precision": results["precision"], + "recall": results["recall"], + "f1": results["f1"], + } diff --git a/src/deepforest_finetuning/prediction/_prediction.py b/src/deepforest_finetuning/prediction/_prediction.py index de697cd..afd5c43 100644 --- a/src/deepforest_finetuning/prediction/_prediction.py +++ b/src/deepforest_finetuning/prediction/_prediction.py @@ -55,14 +55,12 @@ def prediction( if predict_tile: pred = model.predict_tile( image=tree_dataset[img_idx].astype(np.float32), - return_plot=False, patch_size=patch_size, patch_overlap=patch_overlap, ) else: pred = model.predict_image( image=tree_dataset[img_idx].astype(np.float32), - return_plot=False, ) image_name = tree_dataset.__getname__(img_idx) pred["image_path"] = image_name diff --git a/src/deepforest_finetuning/prediction/_prediction_dataset.py b/src/deepforest_finetuning/prediction/_prediction_dataset.py index 3fac452..4236a6d 100644 --- a/src/deepforest_finetuning/prediction/_prediction_dataset.py +++ b/src/deepforest_finetuning/prediction/_prediction_dataset.py @@ -42,7 +42,21 @@ def __getitem__(self, idx: int) -> npt.NDArray: Returns: Image data. """ img_path = self.image_files[idx] - image_array = np.array(imread(img_path))[:, :, :3].astype(np.uint8) + file_ext = Path(img_path).suffix.lower() + + if file_ext in [".tif", ".tiff"]: + # Use tifffile for TIFF images + image_array = np.array(imread(img_path)) + else: + # Use PIL for other image formats (PNG, JPG, etc.) + image = Image.open(img_path) + image_array = np.array(image) + + # Ensure we only take the first 3 channels if there are more + if image_array.ndim >= 3 and image_array.shape[2] > 3: + image_array = image_array[:, :, :3] + + image_array = image_array.astype(np.uint8) if self.resize_images_to is not None: image = Image.fromarray(image_array) diff --git a/src/deepforest_finetuning/preprocessing/_filter_labels.py b/src/deepforest_finetuning/preprocessing/_filter_labels.py index a758bd6..88ea6c3 100644 --- a/src/deepforest_finetuning/preprocessing/_filter_labels.py +++ b/src/deepforest_finetuning/preprocessing/_filter_labels.py @@ -38,7 +38,12 @@ def filter_bounding_boxed_with_size_based_nms(coco_json: Dict[str, Any], iou_thr for annotation in coco_json["annotations"]: bounding_box = annotation["bbox"] bounding_boxes.append( - [bounding_box[0], bounding_box[1], bounding_box[0] + bounding_box[2], bounding_box[1] + bounding_box[3]] + [ + bounding_box[0], + bounding_box[1], + bounding_box[0] + bounding_box[2], + bounding_box[1] + bounding_box[3], + ] ) bounding_box_sizes.append(bounding_box[2] * bounding_box[3]) diff --git a/src/deepforest_finetuning/preprocessing/_preprocess_manually_corrected_labels.py b/src/deepforest_finetuning/preprocessing/_preprocess_manually_corrected_labels.py index dd3e5aa..877eb11 100644 --- a/src/deepforest_finetuning/preprocessing/_preprocess_manually_corrected_labels.py +++ b/src/deepforest_finetuning/preprocessing/_preprocess_manually_corrected_labels.py @@ -12,7 +12,11 @@ import rasterio from deepforest_finetuning.config import ManuallyCorrectedLabelPreprocessingConfig -from deepforest_finetuning.utils import annotations_to_coco, get_image_size_from_pascal_voc, rescale_coco_json +from deepforest_finetuning.utils import ( + annotations_to_coco, + get_image_size_from_pascal_voc, + rescale_coco_json, +) def preprocess_manually_corrected_labels( # pylint: disable=too-many-locals, too-many-nested-blocks, too-many-statements @@ -55,7 +59,9 @@ def preprocess_manually_corrected_labels( # pylint: disable=too-many-locals, to coco_json = annotations_to_coco(annotations, image_width, image_height, capture_date=capture_date) coco_json = rescale_coco_json( - coco_json, target_image_path, source_image_shape=np.array([image_height, image_width]) + coco_json, + target_image_path, + source_image_shape=np.array([image_height, image_width]), ) output_file_path = (output_label_folder / f"{Path(image_path).stem}_coco").with_suffix(".json") @@ -111,9 +117,15 @@ def preprocess_manually_corrected_labels( # pylint: disable=too-many-locals, to continue clipped_x_min = max(int((x_min_meter - target_top_left[0]) / target_pixel_size[0]), 0) - clipped_x_max = min(int((x_max_meter - target_top_left[0]) / target_pixel_size[0]), target_width) + clipped_x_max = min( + int((x_max_meter - target_top_left[0]) / target_pixel_size[0]), + target_width, + ) clipped_y_min = max(int((target_top_left[1] - y_min_meter) / target_pixel_size[1]), 0) - clipped_y_max = min(int((target_top_left[1] - y_max_meter) / target_pixel_size[1]), target_height) + clipped_y_max = min( + int((target_top_left[1] - y_max_meter) / target_pixel_size[1]), + target_height, + ) clipped_annotation = deepcopy(annotation) clipped_annotation["bbox"] = [ diff --git a/src/deepforest_finetuning/preprocessing/_project_point_cloud_labels.py b/src/deepforest_finetuning/preprocessing/_project_point_cloud_labels.py index a709d36..77adb23 100644 --- a/src/deepforest_finetuning/preprocessing/_project_point_cloud_labels.py +++ b/src/deepforest_finetuning/preprocessing/_project_point_cloud_labels.py @@ -9,7 +9,11 @@ import numpy as np from pointtorch import read from pointtorch.operations.numpy import make_labels_consecutive -from pointtree.operations import cloth_simulation_filtering, create_digital_terrain_model, distance_to_dtm +from pointtree.operations import ( + cloth_simulation_filtering, + create_digital_terrain_model, + distance_to_dtm, +) import rasterio from rasterio.transform import from_origin from skimage.filters.rank import modal @@ -85,7 +89,8 @@ def project_point_cloud_labels( # pylint: disable=too-many-locals, too-many-sta label_image = np.zeros(label_image_shape, dtype=np.int64) valid_mask = np.logical_and( - (pixel_indices >= 0).all(axis=-1), (pixel_indices < np.flip(label_image_shape)).all(axis=-1) + (pixel_indices >= 0).all(axis=-1), + (pixel_indices < np.flip(label_image_shape)).all(axis=-1), ) pixel_indices = pixel_indices[valid_mask] @@ -116,7 +121,9 @@ def project_point_cloud_labels( # pylint: disable=too-many-locals, too-many-sta ) max_height, max_indices = scatter_max( - torch.from_numpy(dist_to_dtm), torch.from_numpy(inverse_indices).long(), dim=-1 + torch.from_numpy(dist_to_dtm), + torch.from_numpy(inverse_indices).long(), + dim=-1, ) instance_ids = valid_point_cloud["instance_id_prediction"].to_numpy()[max_indices.cpu().numpy()] @@ -182,7 +189,13 @@ def project_point_cloud_labels( # pylint: disable=too-many-locals, too-many-sta continue bounding_box = [x_min, y_min, width, height] - annotation = {"id": next_id, "image_id": 0, "category_id": 0, "iscrowd": 0, "bbox": bounding_box} + annotation = { + "id": next_id, + "image_id": 0, + "category_id": 0, + "iscrowd": 0, + "bbox": bounding_box, + } annotation["segmentation"] = coco_bbox_to_polygon(bounding_box) annotations.append(annotation) next_id += 1 @@ -195,9 +208,17 @@ def project_point_cloud_labels( # pylint: disable=too-many-locals, too-many-sta image_capture_date = "" coco_json = { - "info": {"year": "2024", "version": "1.0.0", "date_created": datetime.today().strftime("%Y-%m-%d")}, + "info": { + "year": "2024", + "version": "1.0.0", + "date_created": datetime.today().strftime("%Y-%m-%d"), + }, "licenses": [ - {"id": 0, "name": "Attribution License", "url": "https://creativecommons.org/licenses/by/4.0/"} + { + "id": 0, + "name": "Attribution License", + "url": "https://creativecommons.org/licenses/by/4.0/", + } ], "images": [ { diff --git a/src/deepforest_finetuning/preprocessing/_rescale_images.py b/src/deepforest_finetuning/preprocessing/_rescale_images.py index 8daab5d..33ab7e9 100644 --- a/src/deepforest_finetuning/preprocessing/_rescale_images.py +++ b/src/deepforest_finetuning/preprocessing/_rescale_images.py @@ -39,7 +39,12 @@ def rescale_images(config: ImageRescalingConfig): # pylint: disable=too-many-lo with rasterio.open(original_image_path) as src: transform, width, height = calculate_default_transform( - src.crs, src.crs, src.width, src.height, *src.bounds, resolution=target_resolution + src.crs, + src.crs, + src.width, + src.height, + *src.bounds, + resolution=target_resolution, ) input_pixel_size = np.abs(np.array([src.transform[0], src.transform[4]], dtype=np.float64)) @@ -73,6 +78,7 @@ def rescale_images(config: ImageRescalingConfig): # pylint: disable=too-many-lo label_output_folder.mkdir(exist_ok=True, parents=True) label_subfolders = [x for x in os.listdir(input_label_folder) if os.path.isdir(input_label_folder / x)] + label_subfolders.append(".") for label_subfolder in label_subfolders: label_file_name = f"{original_image_path.stem}_coco.json" label_file = input_label_folder / label_subfolder / label_file_name @@ -86,7 +92,11 @@ def rescale_images(config: ImageRescalingConfig): # pylint: disable=too-many-lo with open(label_file, "r", encoding="utf-8") as f: coco_json = json.load(f) - coco_json = rescale_coco_json(coco_json, target_image_path, source_image_path=original_image_path) + coco_json = rescale_coco_json( + coco_json, + target_image_path, + source_image_path=original_image_path, + ) with open(target_label_path, "w", encoding="utf-8") as f: json.dump(coco_json, f, indent=4) diff --git a/src/deepforest_finetuning/training/_finetuning.py b/src/deepforest_finetuning/training/_finetuning.py index 2459dc3..9c72725 100644 --- a/src/deepforest_finetuning/training/_finetuning.py +++ b/src/deepforest_finetuning/training/_finetuning.py @@ -3,7 +3,6 @@ __all__ = ["split_images_into_patches", "finetuning"] import copy -from functools import partial import os from pathlib import Path import shutil @@ -15,7 +14,7 @@ from deepforest import utilities, preprocess from deepforest import main as deepforest_main from pytorch_lightning import seed_everything, Trainer, LightningModule -from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import CSVLogger from pytorch_lightning.utilities.seed import isolate_rng import numpy as np @@ -27,29 +26,19 @@ from deepforest_finetuning.prediction import prediction as run_prediction -def get_transform(augment: bool, seed: Optional[int] = None): +def get_transform(seed: Optional[int] = None): """ Albumentations transformation of bounding boxes. - Args: - augment: Whether to apply data augmentations. - seed: Random seed for data augmentations to ensure reproducibility. Defaults to :code:`None`. - Returns: Transforms. """ - if augment: - transform = A.Compose( - [A.HorizontalFlip(p=0.5), ToTensorV2()], - bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]), - seed=seed, - ) - - else: - transform = A.Compose( - [ToTensorV2()], bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]), seed=seed - ) + transform = A.Compose( + [A.HorizontalFlip(p=0.5), ToTensorV2()], + bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]), + seed=seed, + ) return transform @@ -107,6 +96,36 @@ def split_images_into_patches( return annotations_path +def _collect_annotation_paths(base_dir: Path, annotation_files: Union[List[str], str]) -> List[str]: + """ + Process annotation file paths, handling both individual files and directories. + + Args: + base_dir: Base directory for relative paths. + annotation_files: List of annotation file paths or directories containing annotation files. + + Returns: + List of processed annotation file paths. + """ + annotation_paths = [] + + if isinstance(annotation_files, str): + annotation_files = [annotation_files] + + for file_path in annotation_files: + path = base_dir / file_path + if path.is_dir(): + json_files = list(path.glob("*.json")) + relative_paths = [str(json_file.relative_to(base_dir)) for json_file in json_files] + annotation_paths.extend(relative_paths) + print(f"INFO: Found {len(json_files)} JSON files in directory {path}.") + else: + annotation_paths.append(file_path) + print(f"INFO: Using annotation file {file_path}.") + + return annotation_paths + + class EvaluationCallBack(Callback): """ Callback that evaluates the model after each training epoch. @@ -122,7 +141,7 @@ def __init__(self, config: TrainingConfig, seed: int): self._base_dir = Path(config.base_dir) self._seed = seed - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): # pylint: disable=too-many-locals """ Hook that evaluates the model after each training epoch. @@ -131,14 +150,17 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): pl_module: Model to be evaluated. """ - trainer = copy.deepcopy(trainer) - pl_module = copy.deepcopy(pl_module) - pl_module.trainer = trainer - # trainer_state = copy.deepcopy(trainer.state) + # Create a copy of the model for evaluation to avoid affecting the random state of the training + eval_model = copy.deepcopy(pl_module) + eval_model.trainer = copy.deepcopy(trainer) + # evaluate on training and test set + train_annotation_files = _collect_annotation_paths(self._base_dir, self._config.train_annotation_files) + test_annotation_files = _collect_annotation_paths(self._base_dir, self._config.test_annotation_files) + for prefix, annotation_files in [ - ("train", self._config.train_annotation_files), - ("test", self._config.test_annotation_files), + ("train", train_annotation_files), + ("test", test_annotation_files), ]: image_files = [] annotations = [] @@ -161,7 +183,7 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): export_config.output_file_name = f"{prefix}_predictions_seed_{self._seed}.csv" run_prediction( - pl_module, + eval_model, image_files=image_files, predict_tile=True, patch_size=self._config.patch_size, @@ -180,15 +202,30 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): / f"{trainer.current_epoch + 1}_epochs" / f"{prefix}_metrics_seed_{self._seed}.csv" ) - evaluate( + metrics = evaluate( prediction, pd.concat(annotations), self._config.iou_threshold, metrics_file, ) + # Log metrics to Lightning logger for each prefix (train/test) + # This makes them available for callbacks like EarlyStopping + for metric_name, metric_value in metrics.items(): + if trainer.logger is not None: + trainer.logger.log_metrics( + {f"{prefix}_{metric_name}": metric_value}, + step=trainer.current_epoch, + ) + + if metric_name == "f1": + # Access callback_metrics directly on the original trainer + trainer.callback_metrics[f"{prefix}_{metric_name}"] = torch.tensor(metric_value) -def finetuning(config: TrainingConfig): # pylint: disable=too-many-locals, too-many-statements + +def finetuning( + config: TrainingConfig, +): # pylint: disable=too-many-locals, too-many-statements """Fine-tunes the DeepForest model.""" torch.set_float32_matmul_precision(config.float32_matmul_precision) @@ -201,9 +238,12 @@ def finetuning(config: TrainingConfig): # pylint: disable=too-many-locals, too- preprocessed_image_folders = {} preprocessed_annotation_files = {} - splitting_configs = [("train", config.train_annotation_files)] + train_annotation_files = _collect_annotation_paths(base_dir, config.train_annotation_files) + + splitting_configs = [("train", train_annotation_files)] if config.pretrain_annotation_files is not None and len(config.pretrain_annotation_files) > 0: - splitting_configs.append(("pretraining", config.pretrain_annotation_files)) + pretrain_annotation_files = _collect_annotation_paths(base_dir, config.pretrain_annotation_files) + splitting_configs.append(("pretraining", pretrain_annotation_files)) for prefix, annotation_files in splitting_configs: annotations = [] @@ -226,6 +266,9 @@ def finetuning(config: TrainingConfig): # pylint: disable=too-many-locals, too- print("\nStarting training ...") + if config.early_stopping_patience is not None: + print(f"Early stopping enabled with patience of {config.early_stopping_patience} epochs") + for seed in config.seeds: # set seeds for reproducibility seed_everything(seed, workers=True, verbose=True) @@ -234,7 +277,7 @@ def finetuning(config: TrainingConfig): # pylint: disable=too-many-locals, too- print(f"INFO: Training for {config.epochs} epochs with seed {seed}...") # load model - model = deepforest_main.deepforest(transforms=partial(get_transform, seed=seed)) + model = deepforest_main.deepforest(transforms=get_transform(seed=seed)) model.use_release() # copy config to avoid overwriting @@ -242,23 +285,39 @@ def finetuning(config: TrainingConfig): # pylint: disable=too-many-locals, too- # configure model if current_config.pretrain_learning_rate is None: - model.config["train"]["lr"] = current_config.learning_rate + model.config.train.lr = current_config.learning_rate else: - model.config["train"]["lr"] = current_config.pretrain_learning_rate + model.config.train.lr = current_config.pretrain_learning_rate + + model.config.train.epochs = config.epochs - model.config["train"]["epochs"] = config.epochs - model.config["save-snapshot"] = False + if "pretraining" in preprocessed_annotation_files: + model.config["train"]["csv_file"] = preprocessed_annotation_files["pretraining"] + model.config["train"]["root_dir"] = preprocessed_image_folders["pretraining"] + logger = CSVLogger( + config.log_dir, + name=f"{config.epochs}_epochs_seed_{seed}_pretraining", + ) + + pretraining_callbacks = [] + if config.early_stopping_patience is not None: + pretraining_callbacks.append( + EarlyStopping( + monitor=config.target_metric, + min_delta=0.0, + patience=config.early_stopping_patience, + verbose=True, + mode="min" if "loss" in config.target_metric else "max", + ) + ) - if "pretrain" in annotation_files: - model.config["train"]["csv_file"] = preprocessed_annotation_files["pretrain"] - model.config["train"]["root_dir"] = preprocessed_image_folders["pretrain"] - logger = CSVLogger(config.log_dir, name=f"{config.epochs}_epochs_seed_{seed}_pretraining") model.create_trainer( precision=config.precision if torch.cuda.is_available() else 32, log_every_n_steps=1, benchmark=False, deterministic=True, logger=logger, + callbacks=pretraining_callbacks if pretraining_callbacks else None, ) model.trainer.fit(model) @@ -269,14 +328,26 @@ def finetuning(config: TrainingConfig): # pylint: disable=too-many-locals, too- callbacks: List[Callback] = [EvaluationCallBack(config, seed)] if config.checkpoint_dir is not None: - callbacks.append( ModelCheckpoint( dirpath=base_dir / config.checkpoint_dir, filename="{epoch}_" + f"seed={seed}", - save_top_k=-1, + monitor=config.target_metric, + save_top_k=config.save_top_k, every_n_epochs=1, enable_version_counter=False, + mode="min" if "loss" in config.target_metric else "max", + ) + ) + + if config.early_stopping_patience is not None: + callbacks.append( + EarlyStopping( + monitor=config.target_metric, + min_delta=0.0, + patience=config.early_stopping_patience, + verbose=True, + mode="min" if "loss" in config.target_metric else "max", ) ) diff --git a/src/deepforest_finetuning/utils/_annotations_to_coco.py b/src/deepforest_finetuning/utils/_annotations_to_coco.py index 84365a9..0250fae 100644 --- a/src/deepforest_finetuning/utils/_annotations_to_coco.py +++ b/src/deepforest_finetuning/utils/_annotations_to_coco.py @@ -11,7 +11,10 @@ def annotations_to_coco( - annotations: pd.DataFrame, image_width: int, image_height: int, capture_date: Optional[str] = None + annotations: pd.DataFrame, + image_width: int, + image_height: int, + capture_date: Optional[str] = None, ) -> Dict[str, Any]: """ Converts DeepForest annotations into COCO format. @@ -72,8 +75,18 @@ def annotations_to_coco( next_id += 1 coco_json = { - "info": {"year": "2024", "version": "1.0.0", "date_created": datetime.today().strftime("%Y-%m-%d")}, - "licenses": [{"id": 0, "name": "Attribution License", "url": "https://creativecommons.org/licenses/by/4.0/"}], + "info": { + "year": "2024", + "version": "1.0.0", + "date_created": datetime.today().strftime("%Y-%m-%d"), + }, + "licenses": [ + { + "id": 0, + "name": "Attribution License", + "url": "https://creativecommons.org/licenses/by/4.0/", + } + ], "images": [ { "id": 0,