diff --git a/roboflow/util/image_utils.py b/roboflow/util/image_utils.py index 71a32824..6d556ab3 100644 --- a/roboflow/util/image_utils.py +++ b/roboflow/util/image_utils.py @@ -103,7 +103,10 @@ def load_labelmap(f): with open(f) as file: data = yaml.safe_load(file) if "names" in data: - return {i: name for i, name in enumerate(data["names"])} + names = data["names"] + if isinstance(names, dict): + return {int(i): name for i, name in names.items()} + return {i: name for i, name in enumerate(names)} else: with open(f) as file: lines = [line for line in file.readlines() if line.strip()] diff --git a/tests/util/test_image_utils.py b/tests/util/test_image_utils.py index 5a17fe37..ff4081a3 100644 --- a/tests/util/test_image_utils.py +++ b/tests/util/test_image_utils.py @@ -1,8 +1,10 @@ +import os +import tempfile import unittest import responses -from roboflow.util.image_utils import check_image_path, check_image_url +from roboflow.util.image_utils import check_image_path, check_image_url, load_labelmap class TestCheckImagePath(unittest.TestCase): @@ -33,3 +35,17 @@ def test_url_not_found(self): url = "https://example.com/notfound.png" responses.add(responses.HEAD, url, status=404) self.assertFalse(check_image_url(url)) + + +class TestLoadLabelmap(unittest.TestCase): + def test_yaml_dict_names(self): + with tempfile.NamedTemporaryFile("w+", suffix=".yaml", delete=False) as tmp: + tmp.write("names:\n 0: abc\n 1: def\n") + tmp.flush() + result = load_labelmap(tmp.name) + os.unlink(tmp.name) + self.assertEqual(result, {0: "abc", 1: "def"}) + + def test_yaml_list_names(self): + result = load_labelmap("tests/datasets/sharks-tiny-yolov9/data.yaml") + self.assertEqual(result, {0: "fish", 1: "primary", 2: "shark"})