diff --git a/scripts/upload_to_hf.py b/scripts/upload_to_hf.py index 3c590e2..d255411 100644 --- a/scripts/upload_to_hf.py +++ b/scripts/upload_to_hf.py @@ -15,7 +15,6 @@ from __future__ import annotations import argparse -import csv import json import logging import sys @@ -30,49 +29,6 @@ DEFAULT_DATASET_ROOT = REPO_ROOT / "data" / "gdb-dataset" HF_REPO_ID = "lica-world/GDB" -SKIP_BENCHMARKS = set() - -# Benchmarks whose load_data() is too slow (image compositing, alpha checks) -# and should be loaded directly from their manifest CSVs instead. -MANIFEST_BENCHMARKS = { - "layout-1": { - "csv": "layout2_manifest.csv", - "prompt_key": "prompt", - "gt_key": "source_layout", - "image_key": "reference_image", - }, - "layout-2": { - "csv": "layout_single_manifest.csv", - "prompt_key": "prompt", - "gt_key": "ground_truth_image", - "image_key": "input_composite", - }, - "layout-3": { - "csv": "g4_firestore_image_gen_pairs_manifest.filtered_component_renders.csv", - "prompt_key": None, - "gt_key": None, - "image_key": "a_image_path", - }, - "layout-8": { - "csv": "g15_object_insertion_manifest.csv", - "prompt_key": "prompt", - "gt_key": "ground_truth_image", - "image_key": "masked_layout", - }, - "typography-7": { - "csv": "g10_text_element_manifest.csv", - "prompt_key": "prompt", - "gt_key": "ground_truth_image", - "image_key": "input_image", - }, - "typography-8": { - "csv": "g10_text_inpaint_manifest.csv", - "prompt_key": "prompt", - "gt_key": "ground_truth_image", - "image_key": "input_image", - }, -} - def _serialize(value: Any) -> str: if isinstance(value, str): @@ -98,202 +54,17 @@ def _is_video(path: str) -> bool: return path.lower().endswith(".mp4") -def _read_csv(csv_path: Path) -> List[Dict[str, str]]: - with open(csv_path, "r", encoding="utf-8") as f: - return list(csv.DictReader(f)) - - -def _read_json(json_path: Path) -> Any: - with open(json_path, encoding="utf-8") as f: - return json.load(f) - - -def load_csv_benchmark( - benchmark_id: str, - meta: Any, - data_dir: Path, - dataset_root: Path, -) -> List[Dict[str, Any]]: - csv_path = data_dir / "samples.csv" - if not csv_path.exists(): - raise FileNotFoundError(f"samples.csv not found in {data_dir}") - - rows_out = [] - base = dataset_root.resolve() - - for row in _read_csv(csv_path): - img_rel = row.get("image_path", "") - img_abs = str((base / img_rel).resolve()) if img_rel else "" - is_vid = _is_video(img_abs) if img_abs else False - - has_image = bool(img_abs and not is_vid - and Path(img_abs).exists() and _is_image_file(img_abs)) - - extra = {k: v for k, v in row.items() - if k not in ("sample_id", "prompt", "image_path", "expected_output")} - - rows_out.append({ - "sample_id": row.get("sample_id", ""), - "benchmark_id": benchmark_id, - "domain": meta.domain, - "task_type": meta.task_type.value, - "benchmark_name": meta.name, - "prompt": row.get("prompt", ""), - "ground_truth": row.get("expected_output", ""), - "image": img_abs if has_image else None, - "media_path": img_rel, - "media_type": "video" if is_vid else ("image" if has_image else "none"), - "metadata": json.dumps(extra, ensure_ascii=False) if extra else "{}", - }) - - return rows_out - - -JSON_FIELD_MAP = { - "svg-1": {"gt_key": "answer", "extra": ["svg_code", "question", "options"]}, - "svg-2": {"gt_key": "answer", "extra": ["svg_code", "question", "options"]}, - "svg-3": {"gt_key": "fixed_svg", "extra": ["bug_svg", "error_type", "difficulty"]}, - "svg-4": {"gt_key": None, "extra": ["origin_svg", "opti_ratio"]}, - "svg-5": {"gt_key": "answer", "extra": ["original_svg", "command"]}, - "svg-6": {"gt_key": None, "extra": ["description"]}, - "svg-7": {"gt_key": None, "extra": ["description"]}, - "svg-8": {"gt_key": None, "extra": ["description"]}, - "lottie-1": {"gt_key": None, "extra": ["description"]}, - "lottie-2": {"gt_key": None, "extra": ["description"]}, - "template-1": {"gt_key": "label", "extra": []}, - "template-2": {"gt_key": None, "extra": []}, - "template-3": {"gt_key": None, "extra": ["n_clusters"]}, - "template-4": {"gt_key": None, "extra": []}, - "template-5": {"gt_key": None, "extra": ["difficulty"]}, -} - - -def load_json_benchmark( - benchmark_id: str, - meta: Any, - data_dir: Path, - dataset_root: Path, -) -> List[Dict[str, Any]]: - json_path = data_dir / f"{benchmark_id}.json" - if not json_path.exists(): - raise FileNotFoundError(f"{benchmark_id}.json not found in {data_dir}") - - data = _read_json(json_path) - base = dataset_root.resolve() - rows_out = [] - - if isinstance(data, list): - items = data - else: - for key in ("samples", "items", "pairs", "queries", "problems"): - if key in data: - items = data[key] - break - else: - items = [data] - - for item in items: - sid = str(item.get("id", item.get("sample_id", ""))) - - gt = item.get("answer", item.get("ground_truth", item.get("label", ""))) - if isinstance(gt, (dict, list)): - gt = json.dumps(gt, ensure_ascii=False) - else: - gt = str(gt) if gt is not None else "" - - prompt = item.get("question", item.get("prompt", item.get("description", ""))) - if isinstance(prompt, (dict, list)): - prompt = json.dumps(prompt, ensure_ascii=False) - - img_rel = item.get("image_path", item.get("image", "")) - img_abs = "" - if img_rel and isinstance(img_rel, str): - candidate = base / img_rel - if candidate.exists(): - img_abs = str(candidate.resolve()) - else: - candidate2 = data_dir / img_rel - if candidate2.exists(): - img_abs = str(candidate2.resolve()) - - is_vid = _is_video(img_abs) if img_abs else False - has_image = bool(img_abs and not is_vid - and Path(img_abs).exists() and _is_image_file(img_abs)) - - skip_keys = {"id", "sample_id", "answer", "ground_truth", "label", - "question", "prompt", "description", "image_path", "image"} - extra = {k: v for k, v in item.items() if k not in skip_keys} - - rows_out.append({ - "sample_id": sid, - "benchmark_id": benchmark_id, - "domain": meta.domain, - "task_type": meta.task_type.value, - "benchmark_name": meta.name, - "prompt": str(prompt) if prompt else "", - "ground_truth": gt, - "image": img_abs if has_image else None, - "media_path": str(img_rel) if img_rel else "", - "media_type": "video" if is_vid else ("image" if has_image else "none"), - "metadata": json.dumps(extra, ensure_ascii=False, default=str) if extra else "{}", - }) - - return rows_out - - -def load_manifest_benchmark( - benchmark_id: str, - meta: Any, - data_dir: Path, - dataset_root: Path, -) -> List[Dict[str, Any]]: - spec = MANIFEST_BENCHMARKS[benchmark_id] - csv_path = data_dir / spec["csv"] - if not csv_path.exists(): - raise FileNotFoundError(f"{spec['csv']} not found in {data_dir}") - - base = dataset_root.resolve() - prompt_key = spec["prompt_key"] - gt_key = spec["gt_key"] - image_key = spec["image_key"] - skip_keys = {"sample_id", prompt_key, gt_key, image_key} - {None} - - rows_out = [] - for row in _read_csv(csv_path): - sid = row.get("sample_id", row.get("pair_id", "")) - prompt = row.get(prompt_key, "") if prompt_key else "" - gt_raw = row.get(gt_key, "") if gt_key else "" - - img_rel = row.get(image_key, "") if image_key else "" - img_abs = "" - if img_rel: - for candidate_base in [data_dir, base]: - candidate = candidate_base / img_rel - if candidate.exists(): - img_abs = str(candidate.resolve()) - break - - is_vid = _is_video(img_abs) if img_abs else False - has_image = bool(img_abs and not is_vid - and Path(img_abs).exists() and _is_image_file(img_abs)) - - extra = {k: v for k, v in row.items() if k not in skip_keys} - - rows_out.append({ - "sample_id": sid, - "benchmark_id": benchmark_id, - "domain": meta.domain, - "task_type": meta.task_type.value, - "benchmark_name": meta.name, - "prompt": prompt, - "ground_truth": gt_raw, - "image": img_abs if has_image else None, - "media_path": img_rel, - "media_type": "video" if is_vid else ("image" if has_image else "none"), - "metadata": json.dumps(extra, ensure_ascii=False, default=str) if extra else "{}", - }) - - return rows_out +def _normalize_paths(value: Any, dataset_root_str: str) -> Any: + """Strip ``dataset_root`` prefix from absolute paths so parquet is portable.""" + if isinstance(value, str): + if value.startswith(dataset_root_str + "/"): + return value[len(dataset_root_str) + 1:] + return value + if isinstance(value, list): + return [_normalize_paths(v, dataset_root_str) for v in value] + if isinstance(value, dict): + return {k: _normalize_paths(v, dataset_root_str) for k, v in value.items()} + return value def load_via_registry( @@ -304,18 +75,29 @@ def load_via_registry( dataset_root: Path, ) -> List[Dict[str, Any]]: bench = registry.get(benchmark_id) + logger.info(" %s: calling bench.load_data()…", benchmark_id) + t0 = time.time() samples = bench.load_data(data_dir, dataset_root=str(dataset_root)) + logger.info(" %s: load_data produced %d samples in %.1fs", + benchmark_id, len(samples), time.time() - t0) + dataset_root_str = str(Path(dataset_root).resolve()) rows_out = [] - for sample in samples: + # Only fields promoted to their own parquet column are excluded from + # ``metadata``; everything else (including path-valued keys) must survive. + metadata_skip = {"sample_id", "ground_truth", "prompt"} + for i, sample in enumerate(samples): + if i and i % 500 == 0: + logger.info(" %s: packed %d/%d rows", benchmark_id, i, len(samples)) img_path = _find_image(sample) is_vid = _is_video(img_path) if img_path else False has_image = bool(img_path and not is_vid and Path(img_path).exists() and _is_image_file(img_path)) - skip = {"sample_id", "ground_truth", "prompt", "image_path", - "input_image", "input_composite", "source_image", "video_path"} - extra = {k: v for k, v in sample.items() if k not in skip} + extra = {k: _normalize_paths(v, dataset_root_str) + for k, v in sample.items() if k not in metadata_skip} + + media_path_rel = _normalize_paths(img_path, dataset_root_str) if img_path else "" rows_out.append({ "sample_id": str(sample.get("sample_id", "")), @@ -326,7 +108,7 @@ def load_via_registry( "prompt": sample.get("prompt", ""), "ground_truth": _serialize(sample.get("ground_truth", "")), "image": img_path if has_image else None, - "media_path": img_path or "", + "media_path": media_path_rel, "media_type": "video" if is_vid else ("image" if has_image else "none"), "metadata": json.dumps(extra, ensure_ascii=False, default=str) if extra else "{}", }) @@ -339,38 +121,21 @@ def load_benchmark( benchmark_id: str, dataset_root: Path, ) -> List[Dict[str, Any]]: - if benchmark_id in SKIP_BENCHMARKS: - logger.info("Skipping %s (excluded)", benchmark_id) - return [] - bench = registry.get(benchmark_id) - meta = bench.meta - try: data_dir = bench.resolve_data_dir(dataset_root) except FileNotFoundError as exc: logger.warning("Skipping %s: %s", benchmark_id, exc) return [] - csv_path = data_dir / "samples.csv" - json_path = data_dir / f"{benchmark_id}.json" - t0 = time.time() try: - if benchmark_id in MANIFEST_BENCHMARKS: - rows = load_manifest_benchmark(benchmark_id, meta, data_dir, dataset_root) - elif csv_path.exists(): - rows = load_csv_benchmark(benchmark_id, meta, data_dir, dataset_root) - elif json_path.exists(): - rows = load_json_benchmark(benchmark_id, meta, data_dir, dataset_root) - else: - rows = load_via_registry(registry, benchmark_id, meta, data_dir, dataset_root) + rows = load_via_registry(registry, benchmark_id, bench.meta, data_dir, dataset_root) except Exception as exc: logger.warning("Failed to load %s: %s: %s", benchmark_id, type(exc).__name__, exc) return [] - dt = time.time() - t0 - logger.info("Loaded %s: %d samples (%.1fs)", benchmark_id, len(rows), dt) + logger.info("Loaded %s: %d samples (%.1fs)", benchmark_id, len(rows), time.time() - t0) return rows @@ -403,6 +168,32 @@ def build_dataset(all_rows: List[Dict[str, Any]]): return datasets.Dataset.from_list(all_rows, features=features) +_BENCHMARK_ID_PATTERN = r"^[a-z]+-\d+$" + + +def _merge_card_configs(repo_id: str, new_configs: List[str]) -> List[str]: + """Union new configs with any already on the Hub, so partial uploads don't + drop existing declarations.""" + import re + + from huggingface_hub import hf_hub_download + + try: + existing_readme = hf_hub_download( + repo_id=repo_id, repo_type="dataset", filename="README.md", + ) + except Exception: + return sorted(set(new_configs)) + + content = Path(existing_readme).read_text(encoding="utf-8") + existing = set() + for match in re.finditer(r"- config_name:\s*([^\s]+)", content): + name = match.group(1).strip() + if re.match(_BENCHMARK_ID_PATTERN, name): + existing.add(name) + return sorted(existing | set(new_configs)) + + def generate_dataset_card(config_names: Optional[List[str]] = None) -> str: if config_names is None: config_names = ["all"] @@ -580,12 +371,13 @@ def main(): commit_message=f"Upload GDB benchmark: {bid}") logger.info("Uploading dataset card...") + card_config_names = _merge_card_configs(args.repo_id, sorted(per_benchmark.keys())) api.upload_file( - path_or_fileobj=generate_dataset_card(sorted(per_benchmark.keys())).encode("utf-8"), + path_or_fileobj=generate_dataset_card(card_config_names).encode("utf-8"), path_in_repo="README.md", repo_id=args.repo_id, repo_type="dataset", - commit_message="Add dataset card", + commit_message="Update dataset card", ) logger.info("Done! https://huggingface.co/datasets/%s", args.repo_id) diff --git a/src/gdb/hf.py b/src/gdb/hf.py index 76ccb9d..be38dfb 100644 --- a/src/gdb/hf.py +++ b/src/gdb/hf.py @@ -1,11 +1,7 @@ """Load benchmark samples from the HuggingFace Hub dataset (lica-world/GDB). -This module provides a drop-in alternative to the local file-based -``load_data()`` path. When ``--dataset-root`` is not provided, the runner -can call ``load_from_hub()`` to fetch data directly from HuggingFace. - -Images are cached to disk so that ``build_model_input()`` gets file paths -it can pass to model APIs, matching the local-file contract. +Used when the runner has no ``--dataset-root``. Images are cached to disk +so ``build_model_input()`` receives file paths, matching the local contract. """ from __future__ import annotations @@ -69,15 +65,12 @@ def load_from_hub( repo_id: str = HF_REPO_ID, cache_dir: Optional[Path] = None, ) -> List[Dict[str, Any]]: - """Load samples for *benchmark_id* from the HuggingFace Hub dataset. - - Returns a list of dicts matching the contract of - ``BaseBenchmark.load_data()`` — at minimum ``sample_id`` and - ``ground_truth``, plus task-specific fields unpacked from the - ``metadata`` column. + """Load samples for *benchmark_id* from the HuggingFace Hub. - Images are saved to *cache_dir* (default ``~/.cache/gdb/images/``) - so downstream code receives file path strings, not PIL objects. + Matches ``BaseBenchmark.load_data()``: returns dicts with ``sample_id``, + ``ground_truth``, plus any fields unpacked from the ``metadata`` column. + Images are cached under *cache_dir* (default ``~/.cache/gdb/images/``) + and surfaced as ``image_path`` strings. """ try: from datasets import load_dataset @@ -101,11 +94,9 @@ def load_from_hub( "prompt": row.get("prompt", ""), } - # Alias prompt into keys that some benchmarks expect sample["question"] = sample["prompt"] sample["description"] = sample["prompt"] - # Unpack task-specific fields from metadata JSON meta_raw = row.get("metadata", "{}") try: extra = json.loads(meta_raw) if meta_raw else {} @@ -116,9 +107,11 @@ def load_from_hub( if k not in sample: sample[k] = v - # Handle image: save PIL to cache, store path + # Generation-only configs type ``image`` as Value("string") and store "", + # so we check for a PIL-like object rather than truthiness. pil_img = row.get("image") - if pil_img is not None: + has_pil = pil_img is not None and hasattr(pil_img, "save") + if has_pil: dest = _image_cache_path(cache_dir, benchmark_id, sample["sample_id"]) if dest.exists(): img_path = str(dest)