diff --git a/fastlabel/__init__.py b/fastlabel/__init__.py
index f581db8..b0f2e46 100644
--- a/fastlabel/__init__.py
+++ b/fastlabel/__init__.py
@@ -7,6 +7,8 @@
from concurrent.futures import ThreadPoolExecutor, wait
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
+from xml.dom import minidom
+from xml.etree import ElementTree as ET
import cv2
import numpy as np
@@ -3548,6 +3550,26 @@ def export_semantic_segmentation(
start_index=start_index,
)
+ def export_cvat(
+ self,
+ tasks: list,
+ output_dir: str = os.path.join("output", "cvat"),
+ pretty_print: bool = True,
+ ) -> None:
+ xml_elements = converters.CvatConverter(logger).tasks_to_cvat(tasks)
+
+ os.makedirs(output_dir, exist_ok=True)
+ output_path = os.path.join(output_dir, "annotations.xml")
+ with open(output_path, "w", encoding="utf-8") as f:
+ if pretty_print:
+ doct = minidom.parseString(ET.tostring(xml_elements))
+ doct.writexml(
+ f, encoding="utf-8", indent=" ", newl="\n", addindent=" "
+ )
+ else:
+ tree = ET.ElementTree(xml_elements)
+ tree.write(f, encoding="unicode", xml_declaration=True)
+
def __export_index_color_image(
self,
task: list,
diff --git a/fastlabel/converters.py b/fastlabel/converters.py
index 3cec06e..695fd09 100644
--- a/fastlabel/converters.py
+++ b/fastlabel/converters.py
@@ -5,10 +5,12 @@
from contextlib import contextmanager
from datetime import datetime
from decimal import Decimal
+from logging import Logger
from operator import itemgetter
from pathlib import Path
from tempfile import NamedTemporaryFile
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
+from xml.etree import ElementTree as ET
import cv2
import geojson
@@ -980,6 +982,204 @@ def __remove_duplicated_coordinates(points: List[int]) -> List[int]:
return new_points
+class CvatConverter:
+ CVAT_VERSION = "1.1"
+ _ANNOTATION_BUILDERS = {
+ "bbox": "_add_bbox",
+ "polygon": "_add_polygon",
+ "polyline": "_add_polyline",
+ "keypoint": "_add_keypoints",
+ "line": "_add_line",
+ "segmentation": "_add_segmentation",
+ }
+
+ def __init__(
+ self,
+ logger: Logger,
+ ) -> None:
+ self.logger = logger
+
+ def tasks_to_cvat(self, tasks: Iterable[dict]) -> ET.Element:
+ """Convert FastLabel tasks to CVAT XML format.
+
+ tasks schema (dict):
+ - task: {
+ "name": str,
+ "width": int,
+ "height": int,
+ "annotations": [annotation, ...]
+ }
+ - annotation: {
+ "id": str,
+ "title": str | None,
+ "type": str, # bbox / polygon / polyline / keypoint / line / segmentation
+ "value": str,
+ "points": Any,
+ "attributes": [{"name": str, "value": Any}, ...],
+ "rotation": float | None (optional, bbox only)
+ }
+
+ returns:
+ - root: ...
+
+ XML outline:
+
+ 1.1
+
+
+ ...
+
+
+
+ """
+ root = ET.Element("annotations")
+ self._make_tag(root, "version", CvatConverter.CVAT_VERSION)
+
+ for index, task in enumerate(tasks):
+ image = self._make_tag(
+ root,
+ "image",
+ attrib={
+ "id": index,
+ "name": task["name"].replace("/", "_"),
+ "width": task["width"],
+ "height": task["height"],
+ },
+ )
+ for annotation in task["annotations"]:
+ elems: list[ET.Element] = []
+ try:
+ fl_type = annotation["type"]
+ if fl_type not in self._ANNOTATION_BUILDERS:
+ raise ValueError(
+ f"Unsupported fastLabel annotation type: {annotation['type']}"
+ )
+ builder = getattr(self, self._ANNOTATION_BUILDERS[fl_type])
+ elems = builder(image, annotation)
+
+ for elem in elems:
+ for attr in annotation["attributes"]:
+ self._make_tag(
+ elem,
+ "attribute",
+ attr["value"],
+ attrib={"name": attr["name"]},
+ )
+ except Exception as e:
+ for elem in elems or []:
+ if elem in image:
+ image.remove(elem)
+
+ self.logger.error(
+ "task_name=%s annotation_id=%s annotation_title=%s annotation_type=%s error=%s",
+ task.get("name"),
+ annotation.get("id"),
+ annotation.get("title"),
+ annotation.get("type"),
+ e,
+ )
+ continue
+
+ return root
+
+ @staticmethod
+ def _make_tag(
+ root: ET.Element, tag: str, value: Any = None, attrib: dict | None = None
+ ) -> ET.Element:
+ safe_attrib = {
+ k: "" if v is None else str(v) for k, v in (attrib or {}).items()
+ }
+ elem = ET.SubElement(root, tag, attrib=safe_attrib)
+ if value is not None:
+ elem.text = str(value)
+ return elem
+
+ def _add_points_shape(
+ self, image_elem: ET.Element, annotation: dict, tag: str
+ ) -> list[ET.Element]:
+ points = list(annotation["points"])
+ if len(points) % 2 != 0:
+ raise ValueError(
+ f"ポイントが偶数ではありません。Annotation({annotation['value']}): {len(points)}"
+ )
+ flatten = []
+ while points:
+ x, y, *points = points
+ flatten.append(f"{x}, {y}")
+ elem = self._make_tag(
+ image_elem,
+ tag,
+ attrib={"label": annotation["value"], "points": ";".join(flatten)},
+ )
+ return [elem]
+
+ def _add_bbox(self, image_elem: ET.Element, annotation: dict) -> list[ET.Element]:
+ points = annotation["points"]
+ if len(points) != 4:
+ raise ValueError("矩形のポイントが 4 つではありません。")
+ elem = self._make_tag(
+ image_elem,
+ "box",
+ attrib={
+ "label": annotation["value"],
+ "xtl": points[0],
+ "ytl": points[1],
+ "xbr": points[2],
+ "ybr": points[3],
+ "rotation": annotation.get("rotation", 0),
+ },
+ )
+ return [elem]
+
+ def _add_polygon(
+ self, image_elem: ET.Element, annotation: dict
+ ) -> list[ET.Element]:
+ return self._add_points_shape(image_elem, annotation, "polygon")
+
+ def _add_polyline(
+ self, image_elem: ET.Element, annotation: dict
+ ) -> list[ET.Element]:
+ return self._add_points_shape(image_elem, annotation, "polyline")
+
+ def _add_keypoints(
+ self, image_elem: ET.Element, annotation: dict
+ ) -> list[ET.Element]:
+ return self._add_points_shape(image_elem, annotation, "points")
+
+ def _add_line(self, image_elem: ET.Element, annotation: dict) -> list[ET.Element]:
+ return self._add_points_shape(image_elem, annotation, "polyline")
+
+ def _add_segmentation(
+ self, image_elem: ET.Element, annotation: dict
+ ) -> list[ET.Element]:
+ polygons = []
+ points = annotation["points"]
+ if not points:
+ raise ValueError("セグメンテーションのポイント数が 0 です。")
+ try:
+ points[0][0][0]
+ except IndexError as exc:
+ raise ValueError("セグメンテーションが3次元配列ではありません。") from exc
+ for segment in points:
+ for region in segment:
+ flatten = []
+ region_points = list(region)
+ while region_points:
+ x, y, *region_points = region_points
+ flatten.append(f"{x}, {y}")
+ polygons.append(
+ self._make_tag(
+ image_elem,
+ "polygon",
+ attrib={
+ "label": annotation["value"],
+ "points": ";".join(flatten),
+ },
+ )
+ )
+ return polygons
+
+
def get_pixel_coordinates(points: List[Union[int, float]]) -> List[int]:
"""
Remove diagonal coordinates and return pixel outline coordinates.