From e70961f1bc4352bbe14003fabb24bfa110d05cce Mon Sep 17 00:00:00 2001 From: Lee Clement Date: Mon, 20 Apr 2026 18:12:42 -0230 Subject: [PATCH] Validate task type before model upload --- roboflow/config.py | 6 ++ roboflow/core/version.py | 26 +++++++- roboflow/core/workspace.py | 2 +- roboflow/util/model_processor.py | 98 +++++++++++++++++++++++++----- tests/test_version.py | 49 +++++++++++++++ tests/util/test_model_processor.py | 74 ++++++++++++++++++++++ 6 files changed, 237 insertions(+), 18 deletions(-) create mode 100644 tests/util/test_model_processor.py diff --git a/roboflow/config.py b/roboflow/config.py index 4f683ddc..bc3eaf03 100644 --- a/roboflow/config.py +++ b/roboflow/config.py @@ -73,6 +73,12 @@ def get_conditional_configuration_variable(key, default): TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation" TYPE_KEYPOINT_DETECTION = "keypoint-detection" +TASK_DET = "det" +TASK_SEG = "seg" +TASK_POSE = "pose" +TASK_CLS = "cls" +TASK_OBB = "obb" + DEFAULT_BATCH_NAME = "Pip Package Upload" DEFAULT_JOB_NAME = "Annotated via API" diff --git a/roboflow/core/version.py b/roboflow/core/version.py index 5e236c61..a6976e25 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -16,6 +16,10 @@ API_URL, APP_URL, DEMO_KEYS, + TASK_CLS, + TASK_DET, + TASK_POSE, + TASK_SEG, TQDM_DISABLE, TYPE_CLASSICATION, TYPE_INSTANCE_SEGMENTATION, @@ -32,7 +36,7 @@ from roboflow.models.semantic_segmentation import SemanticSegmentationModel from roboflow.util.annotations import amend_data_yaml from roboflow.util.general import extract_zip, write_line -from roboflow.util.model_processor import process +from roboflow.util.model_processor import process, task_of_model_type from roboflow.util.versions import get_model_format, get_wrong_dependencies_versions, normalize_yolo_model_type if TYPE_CHECKING: @@ -486,13 +490,31 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best filename (str, optional): The name of the weights file. Defaults to "weights/best.pt". """ model_type = normalize_yolo_model_type(model_type) - zip_file_name = process(model_type, model_path, filename) + zip_file_name, model_type = process(model_type, model_path, filename) if zip_file_name is None: raise RuntimeError("Failed to process model") + self._validate_against_project_type(model_type) self._upload_zip(model_type, model_path, zip_file_name) + def _validate_against_project_type(self, model_type: str) -> None: + # TYPE_SEMANTIC_SEGMENTATION intentionally omitted — no uploader emits it. + expected = { + TYPE_OBJECT_DETECTION: TASK_DET, + TYPE_INSTANCE_SEGMENTATION: TASK_SEG, + TYPE_KEYPOINT_DETECTION: TASK_POSE, + TYPE_CLASSICATION: TASK_CLS, + }.get(self.type) + if expected is None: + return + actual = task_of_model_type(model_type) + if actual != expected: + raise ValueError( + f"Project '{self.project}' is type '{self.type}' (task '{expected}') " + f"but model_type '{model_type}' implies task '{actual}'." + ) + def _upload_zip(self, model_type: str, model_path: str, model_file_name: str): res = requests.get( f"{API_URL}/{self.workspace}/{self.project}/{self.version}" diff --git a/roboflow/core/workspace.py b/roboflow/core/workspace.py index 6790a20b..5b9a81f8 100644 --- a/roboflow/core/workspace.py +++ b/roboflow/core/workspace.py @@ -633,7 +633,7 @@ def deploy_model( raise ValueError(f"Project {project_id} is not accessible in this workspace") model_type = normalize_yolo_model_type(model_type) - zip_file_name = process(model_type, model_path, filename) + zip_file_name, model_type = process(model_type, model_path, filename) if zip_file_name is None: raise RuntimeError("Failed to process model") diff --git a/roboflow/util/model_processor.py b/roboflow/util/model_processor.py index cd0fe2db..c4486ae5 100644 --- a/roboflow/util/model_processor.py +++ b/roboflow/util/model_processor.py @@ -2,14 +2,28 @@ import os import shutil import zipfile -from typing import Callable +from typing import Callable, Optional import yaml +from roboflow.config import TASK_CLS, TASK_DET, TASK_OBB, TASK_POSE, TASK_SEG from roboflow.util.versions import print_warn_for_wrong_dependencies_versions -def process(model_type: str, model_path: str, filename: str) -> str: +def task_of_model_type(model_type: str) -> str: + """Canonical task for a deploy model_type string. + + Non-detect tasks double as the model_type suffix token + (e.g. 'yolov11-seg' -> TASK_SEG). Plain 'yolov11' / 'rfdetr-base' -> TASK_DET. + """ + s = model_type.lower() + for task in (TASK_SEG, TASK_POSE, TASK_CLS, TASK_OBB): + if task in s: + return task + return TASK_DET + + +def process(model_type: str, model_path: str, filename: str) -> tuple[str, str]: processor = _get_processor_function(model_type) return processor(model_type, model_path, filename) @@ -66,7 +80,20 @@ def _get_processor_function(model_type: str) -> Callable: return _process_yolo -def _process_yolo(model_type: str, model_path: str, filename: str) -> str: +def _detect_yolo_task(model_instance) -> Optional[str]: + """Detect the training task of an Ultralytics model instance via its class name.""" + if model_instance is None: + return None + return { + "DetectionModel": TASK_DET, + "SegmentationModel": TASK_SEG, + "PoseModel": TASK_POSE, + "ClassificationModel": TASK_CLS, + "OBBModel": TASK_OBB, + }.get(type(model_instance).__name__) + + +def _process_yolo(model_type: str, model_path: str, filename: str) -> tuple[str, str]: if "yolov8" in model_type: try: import torch @@ -148,6 +175,17 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str: model_instance = model["model"] if "model" in model and model["model"] is not None else model["ema"] + detected_task = _detect_yolo_task(model_instance) + if detected_task: + existing_task = task_of_model_type(model_type) + if existing_task == TASK_DET and detected_task != TASK_DET: + model_type = f"{model_type}-{detected_task}" + elif existing_task != detected_task: + raise ValueError( + f"model_type '{model_type}' implies task '{existing_task}' but the " + f".pt file is a '{detected_task}' checkpoint. Use a matching model_type." + ) + if isinstance(model_instance.names, list): class_names = model_instance.names else: @@ -241,10 +279,26 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str: if file in ["model_artifacts.json", "state_dict.pt"]: raise (ValueError(f"File {file} not found. Please make sure to provide a valid model path.")) - return zip_file_name + return zip_file_name, model_type + + +def _detect_rfdetr_task(checkpoint) -> Optional[str]: + """Detect the training task of an rf-detr checkpoint via `model_name`. + rf-detr currently only supports weight upload for detection and instance + segmentation. Every checkpoint stores the Python class name of the model + (e.g. 'RFDETRNano' vs 'RFDETRSegNano') at `checkpoint["model_name"]`; + this is also what rf-detr's own loader uses to pick a model class. + """ + if not isinstance(checkpoint, dict): + return None + model_name = checkpoint.get("model_name") + if not isinstance(model_name, str): + return None + return TASK_SEG if TASK_SEG in model_name.lower() else TASK_DET -def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str: + +def _process_rfdetr(model_type: str, model_path: str, filename: str) -> tuple[str, str]: _supported_types = [ # Detection models "rfdetr-base", @@ -274,7 +328,20 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str: if pt_file is None: raise RuntimeError("No .pt or .pth model file found in the provided path") - get_classnames_txt_for_rfdetr(model_path, pt_file) + import torch + + checkpoint = torch.load(os.path.join(model_path, pt_file), map_location="cpu", weights_only=False) + + detected_task = _detect_rfdetr_task(checkpoint) + if detected_task: + implied_task = task_of_model_type(model_type) + if detected_task != implied_task: + raise ValueError( + f"model_type '{model_type}' implies task '{implied_task}' but the " + f".pt is a '{detected_task}' rfdetr checkpoint. Use a matching model_type." + ) + + get_classnames_txt_for_rfdetr(model_path, pt_file, checkpoint=checkpoint) # Copy the .pt file to weights.pt if not already named weights.pt if pt_file != "weights.pt": @@ -293,19 +360,20 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str: if os.path.exists(os.path.join(model_path, file)): zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED) - return zip_file_name + return zip_file_name, model_type -def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str): +def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str, checkpoint=None): class_names_path = os.path.join(model_path, "class_names.txt") if os.path.exists(class_names_path): maybe_prepend_dummy_class(class_names_path) return class_names_path - import torch + if checkpoint is None: + import torch - model = torch.load(os.path.join(model_path, pt_file), map_location="cpu", weights_only=False) - args = vars(model["args"]) + checkpoint = torch.load(os.path.join(model_path, pt_file), map_location="cpu", weights_only=False) + args = vars(checkpoint["args"]) if "class_names" in args: with open(class_names_path, "w") as f: for class_name in args["class_names"]: @@ -335,7 +403,7 @@ def maybe_prepend_dummy_class(class_name_file: str): def _process_huggingface( model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz" -) -> str: +) -> tuple[str, str]: # Check if model_path exists if not os.path.exists(model_path): raise FileNotFoundError(f"Model path {model_path} does not exist.") @@ -382,10 +450,10 @@ def _process_huggingface( print("Uploading to Roboflow... May take several minutes.") - return tar_file_name + return tar_file_name, model_type -def _process_yolonas(model_type: str, model_path: str, filename: str = "weights/best.pt") -> str: +def _process_yolonas(model_type: str, model_path: str, filename: str = "weights/best.pt") -> tuple[str, str]: try: import torch except ImportError: @@ -449,4 +517,4 @@ def _process_yolonas(model_type: str, model_path: str, filename: str = "weights/ if file in ["model_artifacts.json", filename]: raise (ValueError(f"File {file} not found. Please make sure to provide a valid model path.")) - return zip_file_name + return zip_file_name, model_type diff --git a/tests/test_version.py b/tests/test_version.py index 8cd5b69c..b8cab69c 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -5,6 +5,12 @@ import responses from roboflow.adapters import rfapi +from roboflow.config import ( + TYPE_CLASSICATION, + TYPE_INSTANCE_SEGMENTATION, + TYPE_KEYPOINT_DETECTION, + TYPE_OBJECT_DETECTION, +) from roboflow.core.version import Version, unwrap_version_id from tests.helpers import get_version @@ -197,3 +203,46 @@ def test_unwrap_version_id_when_only_version_id_is_given() -> None: # then assert result == "3" + + +class TestValidateAgainstProjectType(unittest.TestCase): + def _version(self, project_type): + return get_version(type=project_type) + + def test_detection_project_accepts_plain_yolo(self): + self._version(TYPE_OBJECT_DETECTION)._validate_against_project_type("yolov11") + + def test_detection_project_accepts_rfdetr_detection(self): + self._version(TYPE_OBJECT_DETECTION)._validate_against_project_type("rfdetr-medium") + + def test_detection_project_rejects_seg_model(self): + with self.assertRaises(ValueError): + self._version(TYPE_OBJECT_DETECTION)._validate_against_project_type("yolov11-seg") + + def test_detection_project_rejects_rfdetr_seg(self): + with self.assertRaises(ValueError): + self._version(TYPE_OBJECT_DETECTION)._validate_against_project_type("rfdetr-seg-medium") + + def test_instance_seg_project_accepts_seg_model(self): + self._version(TYPE_INSTANCE_SEGMENTATION)._validate_against_project_type("yolov11-seg") + + def test_instance_seg_project_accepts_rfdetr_seg(self): + self._version(TYPE_INSTANCE_SEGMENTATION)._validate_against_project_type("rfdetr-seg-medium") + + def test_instance_seg_project_rejects_detection(self): + with self.assertRaises(ValueError): + self._version(TYPE_INSTANCE_SEGMENTATION)._validate_against_project_type("yolov11") + + def test_keypoint_project_accepts_pose_model(self): + self._version(TYPE_KEYPOINT_DETECTION)._validate_against_project_type("yolov11-pose") + + def test_keypoint_project_rejects_detection(self): + with self.assertRaises(ValueError): + self._version(TYPE_KEYPOINT_DETECTION)._validate_against_project_type("yolov11") + + def test_classification_project_accepts_cls(self): + self._version(TYPE_CLASSICATION)._validate_against_project_type("yolov11-cls") + + def test_classification_project_rejects_detection(self): + with self.assertRaises(ValueError): + self._version(TYPE_CLASSICATION)._validate_against_project_type("yolov11") diff --git a/tests/util/test_model_processor.py b/tests/util/test_model_processor.py new file mode 100644 index 00000000..1d268e53 --- /dev/null +++ b/tests/util/test_model_processor.py @@ -0,0 +1,74 @@ +import unittest + +from roboflow.config import TASK_CLS, TASK_DET, TASK_OBB, TASK_POSE, TASK_SEG +from roboflow.util.model_processor import ( + _detect_rfdetr_task, + _detect_yolo_task, + task_of_model_type, +) + + +class _FakeModel: + """Stand-in for an Ultralytics model_instance; only __class__.__name__ matters.""" + + +def _make_fake(name: str): + return type(name, (_FakeModel,), {})() + + +class TaskOfModelTypeTest(unittest.TestCase): + def test_detect_defaults(self): + self.assertEqual(task_of_model_type("yolov11"), TASK_DET) + self.assertEqual(task_of_model_type("rfdetr-base"), TASK_DET) + self.assertEqual(task_of_model_type("rfdetr-medium"), TASK_DET) + self.assertEqual(task_of_model_type("yolov8"), TASK_DET) + + def test_segment(self): + self.assertEqual(task_of_model_type("yolov11-seg"), TASK_SEG) + self.assertEqual(task_of_model_type("rfdetr-seg-medium"), TASK_SEG) + self.assertEqual(task_of_model_type("yolov7-seg"), TASK_SEG) + + def test_pose(self): + self.assertEqual(task_of_model_type("yolov11-pose"), TASK_POSE) + + def test_classify(self): + self.assertEqual(task_of_model_type("yolov11-cls"), TASK_CLS) + + def test_obb(self): + self.assertEqual(task_of_model_type("yolov11-obb"), TASK_OBB) + + +class DetectYoloTaskTest(unittest.TestCase): + def test_ultralytics_class_names(self): + cases = { + "SegmentationModel": TASK_SEG, + "PoseModel": TASK_POSE, + "ClassificationModel": TASK_CLS, + "OBBModel": TASK_OBB, + "DetectionModel": TASK_DET, + } + for cls_name, expected in cases.items(): + self.assertEqual(_detect_yolo_task(_make_fake(cls_name)), expected, cls_name) + + def test_unrecognized_returns_none(self): + self.assertIsNone(_detect_yolo_task(_make_fake("SomeOtherModel"))) + self.assertIsNone(_detect_yolo_task(None)) + + +class DetectRfdetrTaskTest(unittest.TestCase): + def test_segmentation_model_names(self): + for name in ("RFDETRSegNano", "RFDETRSegSmall", "RFDETRSegMedium", "RFDETRSegLarge"): + self.assertEqual(_detect_rfdetr_task({"model_name": name}), TASK_SEG, name) + + def test_detection_model_names(self): + for name in ("RFDETRNano", "RFDETRSmall", "RFDETRMedium", "RFDETRLarge", "RFDETRXLarge"): + self.assertEqual(_detect_rfdetr_task({"model_name": name}), TASK_DET, name) + + def test_unrecognized_returns_none(self): + self.assertIsNone(_detect_rfdetr_task({})) + self.assertIsNone(_detect_rfdetr_task({"model_name": None})) + self.assertIsNone(_detect_rfdetr_task({"args": {"segmentation_head": True}})) + + +if __name__ == "__main__": + unittest.main()