diff --git a/roboflow/core/project.py b/roboflow/core/project.py index 00ba0d26..c451d908 100644 --- a/roboflow/core/project.py +++ b/roboflow/core/project.py @@ -886,3 +886,68 @@ def create_annotation_job( raise RuntimeError(f"Failed to create annotation job: {response.text}") return response.json() + + def get_batches(self) -> Dict: + """ + Get a list of all batches in the project. + + Returns: + Dict: A dictionary containing the list of batches + + Example: + >>> import roboflow + + >>> rf = roboflow.Roboflow(api_key="YOUR_API_KEY") + + >>> project = rf.workspace().project("PROJECT_ID") + + >>> batches = project.get_batches() + """ + url = f"{API_URL}/{self.__workspace}/{self.__project_name}/batches?api_key={self.__api_key}" + + response = requests.get(url) + + if response.status_code != 200: + try: + error_data = response.json() + if "error" in error_data: + raise RuntimeError(error_data["error"]) + raise RuntimeError(response.text) + except ValueError: + raise RuntimeError(f"Failed to get batches: {response.text}") + + return response.json() + + def get_batch(self, batch_id: str) -> Dict: + """ + Get information for a specific batch in the project. + + Args: + batch_id (str): The ID of the batch to retrieve + + Returns: + Dict: A dictionary containing the batch details + + Example: + >>> import roboflow + + >>> rf = roboflow.Roboflow(api_key="YOUR_API_KEY") + + >>> project = rf.workspace().project("PROJECT_ID") + + >>> batch = project.get_batch("batch123") + """ + url = f"{API_URL}/{self.__workspace}/{self.__project_name}/batches/{batch_id}?api_key={self.__api_key}" + + response = requests.get(url) + + if response.status_code != 200: + try: + error_data = response.json() + if "error" in error_data: + raise RuntimeError(error_data["error"]) + raise RuntimeError(response.text) + except ValueError: + raise RuntimeError(f"Failed to get batch {batch_id}: {response.text}") + + return response.json() diff --git a/tests/test_project.py b/tests/test_project.py index 353fa836..edd81772 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -455,6 +455,89 @@ def test_project_upload_dataset(self): for mock in mocks.values(): mock.stop() + def test_get_batches_success(self): + expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches?api_key={ROBOFLOW_API_KEY}" + mock_response = { + "batches": [ + { + "name": "Uploaded on 11/22/22 at 1:39 pm", + "numJobs": 2, + "images": 115, + "uploaded": {"_seconds": 1669146024, "_nanoseconds": 818000000}, + "id": "batch-1", + }, + { + "numJobs": 0, + "images": 11, + "uploaded": {"_seconds": 1669236873, "_nanoseconds": 47000000}, + "name": "Upload via API", + "id": "batch-2", + }, + ] + } + + responses.add(responses.GET, expected_url, json=mock_response, status=200) + + batches = self.project.get_batches() + + self.assertIsInstance(batches, dict) + self.assertIn("batches", batches) + self.assertEqual(len(batches["batches"]), 2) + self.assertEqual(batches["batches"][0]["id"], "batch-1") + self.assertEqual(batches["batches"][0]["name"], "Uploaded on 11/22/22 at 1:39 pm") + self.assertEqual(batches["batches"][0]["images"], 115) + self.assertEqual(batches["batches"][0]["numJobs"], 2) + self.assertEqual(batches["batches"][1]["id"], "batch-2") + self.assertEqual(batches["batches"][1]["name"], "Upload via API") + + def test_get_batches_error(self): + expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches?api_key={ROBOFLOW_API_KEY}" + error_response = {"error": "Cannot retrieve batches"} + + responses.add(responses.GET, expected_url, json=error_response, status=404) + + with self.assertRaises(RuntimeError) as context: + self.project.get_batches() + + self.assertEqual(str(context.exception), "Cannot retrieve batches") + + def test_get_batch_success(self): + batch_id = "batch-123" + expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches/{batch_id}?api_key={ROBOFLOW_API_KEY}" + mock_response = { + "batch": { + "name": "Uploaded on 11/22/22 at 1:39 pm", + "numJobs": 2, + "images": 115, + "uploaded": {"_seconds": 1669146024, "_nanoseconds": 818000000}, + "id": batch_id, + } + } + + responses.add(responses.GET, expected_url, json=mock_response, status=200) + + batch = self.project.get_batch(batch_id) + + self.assertIsInstance(batch, dict) + self.assertIn("batch", batch) + self.assertEqual(batch["batch"]["id"], batch_id) + self.assertEqual(batch["batch"]["name"], "Uploaded on 11/22/22 at 1:39 pm") + self.assertEqual(batch["batch"]["images"], 115) + self.assertEqual(batch["batch"]["numJobs"], 2) + self.assertIn("uploaded", batch["batch"]) + + def test_get_batch_error(self): + batch_id = "nonexistent-batch" + expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches/{batch_id}?api_key={ROBOFLOW_API_KEY}" + error_response = {"error": "Batch not found"} + + responses.add(responses.GET, expected_url, json=error_response, status=404) + + with self.assertRaises(RuntimeError) as context: + self.project.get_batch(batch_id) + + self.assertEqual(str(context.exception), "Batch not found") + def test_classification_dataset_upload(self): from roboflow.util import folderparser