# ---------------------------------------------------------------------------
# Source file: fastlabel/__init__.py
#   hunk @@ -7,6 +7,8 @@        -> adds the two ``xml`` imports below
#   hunk @@ -3548,6 +3550,26 @@ -> adds ``export_cvat``
# Source file: fastlabel/converters.py
#   hunk @@ -5,10 +5,12 @@      -> adds ``Logger``, ``Any``, ``Iterable`` and ``ET``
#   hunk @@ -980,6 +982,204 @@  -> adds ``CvatConverter``
# ---------------------------------------------------------------------------
import os
from logging import Logger
from typing import Any, Dict, Iterable, List, Optional, Union
from xml.dom import minidom
from xml.etree import ElementTree as ET


# --- fastlabel/__init__.py --------------------------------------------------
# ``export_cvat`` is a method: in the patch it is inserted at method
# indentation right after ``export_semantic_segmentation`` and right before
# ``__export_index_color_image``.  The enclosing class header is outside this
# excerpt, so the method is shown here at column 0.  The patch relies on
# ``os``, ``converters`` and ``logger`` already being in scope in that module.
def export_cvat(
    self,
    tasks: list,
    output_dir: str = os.path.join("output", "cvat"),
    pretty_print: bool = True,
) -> None:
    """Write ``tasks`` as a CVAT XML file at ``<output_dir>/annotations.xml``.

    Args:
        tasks: Tasks passed unchanged to ``CvatConverter.tasks_to_cvat``.
        output_dir: Target directory; created if it does not exist.
        pretty_print: When true, the tree is re-parsed with ``minidom`` and
            written with newlines and indentation; otherwise it is written
            compactly through ``ElementTree.write``.
    """
    xml_elements = converters.CvatConverter(logger).tasks_to_cvat(tasks)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "annotations.xml")
    with open(output_path, "w", encoding="utf-8") as f:
        if pretty_print:
            doct = minidom.parseString(ET.tostring(xml_elements))
            doct.writexml(
                f, encoding="utf-8", indent=" ", newl="\n", addindent=" "
            )
        else:
            tree = ET.ElementTree(xml_elements)
            tree.write(f, encoding="unicode", xml_declaration=True)


# --- fastlabel/converters.py ------------------------------------------------
# In the patch this class is inserted between
# ``__remove_duplicated_coordinates`` and ``get_pixel_coordinates``.
class CvatConverter:
    """Build a CVAT XML element tree from FastLabel task dictionaries."""

    # Text written into the ``<version>`` element.
    CVAT_VERSION = "1.1"

    # FastLabel annotation type -> name of the method that appends its shape(s).
    _ANNOTATION_BUILDERS = {
        "bbox": "_add_bbox",
        "polygon": "_add_polygon",
        "polyline": "_add_polyline",
        "keypoint": "_add_keypoints",
        "line": "_add_line",
        "segmentation": "_add_segmentation",
    }

    def __init__(
        self,
        logger: Logger,
    ) -> None:
        """Store the logger used to report annotations that cannot be converted."""
        self.logger = logger

    def tasks_to_cvat(self, tasks: Iterable[dict]) -> ET.Element:
        """Convert FastLabel tasks to CVAT XML format.

        tasks schema (dict):
            - task: {
                "name": str,
                "width": int,
                "height": int,
                "annotations": [annotation, ...]
              }
            - annotation: {
                "id": str,
                "title": str | None,
                "type": str,  # bbox / polygon / polyline / keypoint / line / segmentation
                "value": str,
                "points": Any,
                "attributes": [{"name": str, "value": Any}, ...],
                "rotation": float | None (optional, bbox only)
              }

        returns:
            - root: the ``<annotations>`` element

        XML outline:
            <annotations>
              <version>1.1</version>
              <image id="0" name="..." width="..." height="...">
                <box label="..." xtl="..." ytl="..." xbr="..." ybr="..." rotation="...">
                  <attribute name="...">...</attribute>
                </box>
                <polygon label="..." points="x1, y1;x2, y2;...">...</polygon>
                <polyline label="..." points="...">...</polyline>
                <points label="..." points="...">...</points>
              </image>
            </annotations>

        An annotation that cannot be converted (unsupported type, malformed
        points, missing key, ...) is logged through ``self.logger.error`` and
        skipped; nothing it added to the ``<image>`` element is kept.
        """
        root = ET.Element("annotations")
        self._make_tag(root, "version", CvatConverter.CVAT_VERSION)

        for index, task in enumerate(tasks):
            image = self._make_tag(
                root,
                "image",
                attrib={
                    "id": index,
                    "name": task["name"].replace("/", "_"),
                    "width": task["width"],
                    "height": task["height"],
                },
            )
            for annotation in task["annotations"]:
                # Remember how many children exist so a failure can be rolled
                # back completely, including shapes a builder appended before
                # it raised (those never reach the builder's return value).
                children_before = len(image)
                try:
                    fl_type = annotation["type"]
                    if fl_type not in self._ANNOTATION_BUILDERS:
                        raise ValueError(
                            f"Unsupported fastLabel annotation type: {annotation['type']}"
                        )
                    builder = getattr(self, self._ANNOTATION_BUILDERS[fl_type])
                    elems = builder(image, annotation)

                    for elem in elems:
                        for attr in annotation["attributes"]:
                            self._make_tag(
                                elem,
                                "attribute",
                                attr["value"],
                                attrib={"name": attr["name"]},
                            )
                except Exception as e:
                    # Best-effort export: drop this annotation, keep the rest.
                    del image[children_before:]

                    self.logger.error(
                        "task_name=%s annotation_id=%s annotation_title=%s annotation_type=%s error=%s",
                        task.get("name"),
                        annotation.get("id"),
                        annotation.get("title"),
                        annotation.get("type"),
                        e,
                    )
                    continue

        return root

    @staticmethod
    def _make_tag(
        root: ET.Element, tag: str, value: Any = None, attrib: Optional[dict] = None
    ) -> ET.Element:
        """Append a ``tag`` child to ``root`` and return it.

        Attribute values are converted with ``str``; ``None`` becomes an empty
        string.  ``value`` becomes the element text unless it is ``None``.
        """
        safe_attrib = {
            k: "" if v is None else str(v) for k, v in (attrib or {}).items()
        }
        elem = ET.SubElement(root, tag, attrib=safe_attrib)
        if value is not None:
            elem.text = str(value)
        return elem

    @staticmethod
    def _join_points(annotation: dict, points: Iterable[Any]) -> str:
        """Format a flat ``[x1, y1, x2, y2, ...]`` sequence as ``"x1, y1;x2, y2"``.

        Raises:
            ValueError: if the sequence holds an odd number of values.
        """
        coords = list(points)
        if len(coords) % 2 != 0:
            raise ValueError(
                f"ポイントが偶数ではありません。Annotation({annotation['value']}): {len(coords)}"
            )
        # Stride slicing pairs the values in one pass.
        return ";".join(f"{x}, {y}" for x, y in zip(coords[0::2], coords[1::2]))

    def _add_points_shape(
        self, image_elem: ET.Element, annotation: dict, tag: str
    ) -> List[ET.Element]:
        """Append one ``tag`` element carrying the annotation's label and points.

        Raises:
            ValueError: if ``annotation["points"]`` has an odd length.
        """
        encoded = self._join_points(annotation, annotation["points"])
        elem = self._make_tag(
            image_elem,
            tag,
            attrib={"label": annotation["value"], "points": encoded},
        )
        return [elem]

    def _add_bbox(self, image_elem: ET.Element, annotation: dict) -> List[ET.Element]:
        """Append a ``<box>`` element built from ``[xtl, ytl, xbr, ybr]`` points.

        Raises:
            ValueError: if ``annotation["points"]`` does not hold exactly 4 values.
        """
        points = annotation["points"]
        if len(points) != 4:
            raise ValueError("矩形のポイントが 4 つではありません。")
        # A rotation that is present but ``None`` must fall back to 0 as well;
        # otherwise it would be serialised as an empty attribute.
        rotation = annotation.get("rotation")
        elem = self._make_tag(
            image_elem,
            "box",
            attrib={
                "label": annotation["value"],
                "xtl": points[0],
                "ytl": points[1],
                "xbr": points[2],
                "ybr": points[3],
                "rotation": 0 if rotation is None else rotation,
            },
        )
        return [elem]

    def _add_polygon(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append the annotation as a ``<polygon>`` element."""
        return self._add_points_shape(image_elem, annotation, "polygon")

    def _add_polyline(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append the annotation as a ``<polyline>`` element."""
        return self._add_points_shape(image_elem, annotation, "polyline")

    def _add_keypoints(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append the annotation as a ``<points>`` element."""
        return self._add_points_shape(image_elem, annotation, "points")

    def _add_line(self, image_elem: ET.Element, annotation: dict) -> List[ET.Element]:
        """Append the annotation as a ``<polyline>`` element."""
        return self._add_points_shape(image_elem, annotation, "polyline")

    def _add_segmentation(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append one ``<polygon>`` per region of a segmentation annotation.

        ``annotation["points"]`` is read as a three-level nesting: a list of
        segments, each a list of regions, each a flat ``[x1, y1, ...]`` list.

        Raises:
            ValueError: if the points are empty, are not nested three levels
                deep, or a region holds an odd number of values.
        """
        points = annotation["points"]
        if not points:
            raise ValueError("セグメンテーションのポイント数が 0 です。")
        try:
            points[0][0][0]
        except (IndexError, TypeError) as exc:
            # TypeError covers a 2-D input, where points[0][0] is a number.
            raise ValueError("セグメンテーションが3次元配列ではありません。") from exc

        # Encode every region before touching the tree so a malformed region
        # cannot leave a partially exported segmentation behind.
        encoded_regions = [
            self._join_points(annotation, region)
            for segment in points
            for region in segment
        ]
        return [
            self._make_tag(
                image_elem,
                "polygon",
                attrib={"label": annotation["value"], "points": encoded},
            )
            for encoded in encoded_regions
        ]