From f9276350aef4a1953df63f1574a360608bb28bf0 Mon Sep 17 00:00:00 2001 From: Rodrigo Barbosa Date: Mon, 11 Aug 2025 09:53:41 -0300 Subject: [PATCH 1/2] Add model_type parameter support for training - Add model_type parameter to CLI train command - Update version.py to handle model_type in train method - Snake_case naming convention for model_type parameter - Simplify CLI workspace handling - Maintain API payload order minimal --- roboflow/core/version.py | 8 +++++++- roboflow/roboflowpy.py | 41 ++++++++++++++++++++++++++++++++++++++++ tests/manual/debugme.py | 34 ++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/roboflow/core/version.py b/roboflow/core/version.py index f6ca0fc7..7e63b103 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -290,12 +290,13 @@ def export(self, model_format=None): except json.JSONDecodeError: response.raise_for_status() - def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel: + def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel: """ Ask the Roboflow API to train a previously exported version's dataset. Args: speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`. + model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument. checkpoint: A string representing the checkpoint to use while training plot: Whether to plot the training results. Default is `False`. @@ -328,12 +329,17 @@ def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> Inferenc url = f"{API_URL}/{workspace}/{project}/{self.version}/train" data = {} + if speed: data["speed"] = speed if checkpoint: data["checkpoint"] = checkpoint + if model_type: + # API expects camelCase key + data["modelType"] = model_type + write_line("Reaching out to Roboflow to start training...") response = requests.post(url, json=data, params={"api_key": self.__api_key}) diff --git a/roboflow/roboflowpy.py b/roboflow/roboflowpy.py index 48d6c0d7..70cf6db9 100755 --- a/roboflow/roboflowpy.py +++ b/roboflow/roboflowpy.py @@ -19,6 +19,15 @@ def login(args): roboflow.login(force=args.force) +def train(args): + rf = roboflow.Roboflow() + workspace = rf.workspace(args.workspace) # handles None internally + project = workspace.project(args.project) + version = project.version(args.version_number) + model = version.train(model_type=args.model_type, checkpoint=args.checkpoint) + print(model) + + def _parse_url(url): regex = r"(?:https?://)?(?:universe|app)\.roboflow\.(?:com|one)/([^/]+)/([^/]+)(?:/dataset)?(?:/(\d+))?|([^/]+)/([^/]+)(?:/(\d+))?" # noqa: E501 match = re.match(regex, url) @@ -198,6 +207,7 @@ def _argparser(): subparsers = parser.add_subparsers(title="subcommands") _add_login_parser(subparsers) _add_download_parser(subparsers) + _add_train_parser(subparsers) _add_upload_parser(subparsers) _add_import_parser(subparsers) _add_infer_parser(subparsers) @@ -310,6 +320,37 @@ def _add_upload_parser(subparsers): upload_parser.set_defaults(func=upload_image) +def _add_train_parser(subparsers): + train_parser = subparsers.add_parser("train", help="Train a model for a dataset version") + train_parser.add_argument( + "-w", + dest="workspace", + help="specify a workspace url or id (will use default workspace if not specified)", + ) + train_parser.add_argument( + "-p", + dest="project", + help="project_id to train the model for", + ) + train_parser.add_argument( + "-v", + dest="version_number", + type=int, + help="version number to train", + ) + train_parser.add_argument( + "-t", + dest="model_type", + help="type of the model to train (e.g., rfdetr-nano, yolov8n)", + ) + train_parser.add_argument( + "--checkpoint", + dest="checkpoint", + help="checkpoint to resume training from", + ) + train_parser.set_defaults(func=train) + + def _add_import_parser(subparsers): import_parser = subparsers.add_parser("import", help="Import a dataset from a local folder") import_parser.add_argument( diff --git a/tests/manual/debugme.py b/tests/manual/debugme.py index b4d37941..5cdd6566 100644 --- a/tests/manual/debugme.py +++ b/tests/manual/debugme.py @@ -5,6 +5,7 @@ os.environ["ROBOFLOW_CONFIG_DIR"] = f"{thisdir}/data/.config" from roboflow.roboflowpy import _argparser # noqa: E402 +from roboflow import Roboflow # import requests # requests.urllib3.disable_warnings() @@ -12,7 +13,8 @@ rootdir = os.path.abspath(f"{thisdir}/../..") sys.path.append(rootdir) -if __name__ == "__main__": + +def run_cli(): parser = _argparser() # args = parser.parse_args(["login"]) # args = parser.parse_args(f"upload {thisdir}/../datasets/chess -w wolfodorpythontests -p chess".split()) # noqa: E501 // docs @@ -45,3 +47,33 @@ # f"import -w tonyprivate -p meh-plvrv {thisdir}/../datasets/paligemma/".split() # noqa: E501 // docs ) args.func(args) + + +def run_api_train(): + rf = Roboflow() + project = rf.workspace("model-evaluation-workspace").project("penguin-finder") + # version_number = project.generate_version( + # settings={ + # "augmentation": { + # "bbblur": {"pixels": 1.5}, + # "image": {"versions": 2}, + # }, + # "preprocessing": { + # "auto-orient": True, + # }, + # } + # ) + version_number = "18" + print(version_number) + version = project.version(version_number) + model = version.train( + speed="fast", # Options: "fast" (default) or "accurate" (paid feature) + checkpoint=None, # Use a specific checkpoint to continue training + modelType="rfdetr-nano", + ) + print(model) + + +if __name__ == "__main__": + # run_cli() + run_api_train() From 744fad10b689510564149e681d9efeb333fb195a Mon Sep 17 00:00:00 2001 From: Rodrigo Barbosa Date: Mon, 11 Aug 2025 10:03:37 -0300 Subject: [PATCH 2/2] Revert train test --- tests/manual/debugme.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/tests/manual/debugme.py b/tests/manual/debugme.py index 5cdd6566..b4d37941 100644 --- a/tests/manual/debugme.py +++ b/tests/manual/debugme.py @@ -5,7 +5,6 @@ os.environ["ROBOFLOW_CONFIG_DIR"] = f"{thisdir}/data/.config" from roboflow.roboflowpy import _argparser # noqa: E402 -from roboflow import Roboflow # import requests # requests.urllib3.disable_warnings() @@ -13,8 +12,7 @@ rootdir = os.path.abspath(f"{thisdir}/../..") sys.path.append(rootdir) - -def run_cli(): +if __name__ == "__main__": parser = _argparser() # args = parser.parse_args(["login"]) # args = parser.parse_args(f"upload {thisdir}/../datasets/chess -w wolfodorpythontests -p chess".split()) # noqa: E501 // docs @@ -47,33 +45,3 @@ def run_cli(): # f"import -w tonyprivate -p meh-plvrv {thisdir}/../datasets/paligemma/".split() # noqa: E501 // docs ) args.func(args) - - -def run_api_train(): - rf = Roboflow() - project = rf.workspace("model-evaluation-workspace").project("penguin-finder") - # version_number = project.generate_version( - # settings={ - # "augmentation": { - # "bbblur": {"pixels": 1.5}, - # "image": {"versions": 2}, - # }, - # "preprocessing": { - # "auto-orient": True, - # }, - # } - # ) - version_number = "18" - print(version_number) - version = project.version(version_number) - model = version.train( - speed="fast", # Options: "fast" (default) or "accurate" (paid feature) - checkpoint=None, # Use a specific checkpoint to continue training - modelType="rfdetr-nano", - ) - print(model) - - -if __name__ == "__main__": - # run_cli() - run_api_train()