diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index b09f5ad1c..276e98cb0 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -38,6 +38,7 @@ "TDC23 [@mazeika2023tdc],\n", "ToxicChat [@lin2023toxicchat],\n", "VLSU [@palaskar2025vlsu],\n", + "VLGuard [@zong2024vlguard],\n", "XSTest [@rottger2023xstest],\n", "AILuminate [@vidgen2024ailuminate],\n", "Transphobia Awareness [@scheuerman2025transphobia],\n", diff --git a/doc/code/datasets/1_loading_datasets.py b/doc/code/datasets/1_loading_datasets.py index 5a36e2d00..dd4049e33 100644 --- a/doc/code/datasets/1_loading_datasets.py +++ b/doc/code/datasets/1_loading_datasets.py @@ -42,6 +42,7 @@ # TDC23 [@mazeika2023tdc], # ToxicChat [@lin2023toxicchat], # VLSU [@palaskar2025vlsu], +# VLGuard [@zong2024vlguard], # XSTest [@rottger2023xstest], # AILuminate [@vidgen2024ailuminate], # Transphobia Awareness [@scheuerman2025transphobia], diff --git a/doc/references.bib b/doc/references.bib index 0a2cebc18..7fc1c7689 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -519,6 +519,14 @@ @article{rottger2023xstest url = {https://arxiv.org/abs/2308.01263}, } +@article{zong2024vlguard, + title = {Safety Fine-Tuning at (Almost) No Cost: A Baseline for Vision Large Language Models}, + author = {Yongshuo Zong and Ondrej Bohdal and Tingyang Yu and Yongxin Yang and Timothy Hospedales}, + journal = {arXiv preprint arXiv:2402.02207}, + year = {2024}, + url = {https://arxiv.org/abs/2402.02207}, +} + @article{vidgen2024ailuminate, title = {Introducing v0.5 of the {AI} Safety Benchmark from {MLCommons}}, author = {Bertie Vidgen and Adarsh Agrawal and Ahmed M. 
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import uuid
import zipfile
from enum import Enum
from pathlib import Path

from pyrit.common.path import DB_DATA_PATH
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
    _RemoteDatasetLoader,
)
from pyrit.models import SeedDataset, SeedPrompt

logger = logging.getLogger(__name__)

# HuggingFace repository that hosts the (gated) VLGuard dataset.
_HF_REPO_ID = "ys-zong/VLGuard"


class VLGuardCategory(Enum):
    """
    Categories in the VLGuard dataset.

    PRIVACY: Content involving personal data exposure or surveillance (e.g., reading IDs, tracking individuals).
    RISKY_BEHAVIOR: Content depicting or encouraging dangerous activities (e.g., violence, professional advice).
    DECEPTION: Content related to misleading or false information (e.g., disinformation, political manipulation).
    DISCRIMINATION: Content targeting groups based on identity (e.g., discrimination by sex or race).
    """

    PRIVACY = "privacy"
    RISKY_BEHAVIOR = "risky behavior"
    DECEPTION = "deception"
    DISCRIMINATION = "discrimination"


class VLGuardSubcategory(Enum):
    """Subcategories in the VLGuard dataset, nested under the main categories."""

    PERSONAL_DATA = "personal data"
    PROFESSIONAL_ADVICE = "professional advice"
    POLITICAL = "political"
    SEXUALLY_EXPLICIT = "sexually explicit"
    VIOLENCE = "violence"
    DISINFORMATION = "disinformation"
    SEX = "sex"
    RACE = "race"
    OTHER = "other"


class VLGuardSubset(Enum):
    """
    Evaluation subsets in the VLGuard dataset.

    UNSAFES: Unsafe images with instructions — tests whether the model refuses unsafe visual content.
    SAFE_UNSAFES: Safe images with unsafe instructions — tests whether the model refuses unsafe text prompts.
    SAFE_SAFES: Safe images with safe instructions — tests whether the model remains helpful.
    """

    UNSAFES = "unsafes"
    SAFE_UNSAFES = "safe_unsafes"
    SAFE_SAFES = "safe_safes"


class _VLGuardDataset(_RemoteDatasetLoader):
    """
    Loader for the VLGuard multimodal safety dataset.

    VLGuard contains image-instruction pairs for evaluating vision-language model safety.
    It includes both unsafe and safe images paired with various instructions to test whether
    models refuse unsafe content while remaining helpful on safe content.

    The dataset covers 4 categories (privacy, risky behavior, deception, discrimination)
    with 9 subcategories (personal data, professional advice, political, sexually explicit,
    violence, disinformation, sex, race, other).

    Note: This is a gated dataset on HuggingFace. You must accept the terms at
    https://huggingface.co/datasets/ys-zong/VLGuard before use, and provide
    a HuggingFace token.

    Reference: https://arxiv.org/abs/2402.02207
    Paper: Safety Fine-Tuning at (Almost) No Cost: A Baseline for Vision Large Language Models (ICML 2024)
    """

    def __init__(
        self,
        *,
        subset: VLGuardSubset = VLGuardSubset.UNSAFES,
        categories: list[VLGuardCategory] | None = None,
        max_examples: int | None = None,
        token: str | None = None,
    ) -> None:
        """
        Initialize the VLGuard dataset loader.

        Args:
            subset (VLGuardSubset): Which evaluation subset to load. Defaults to UNSAFES.
            categories (list[VLGuardCategory] | None): List of VLGuard categories to filter by.
                If None, all categories are included.
            max_examples (int | None): Maximum number of multimodal examples to fetch. Each example
                produces 2 prompts (text + image). If None, fetches all examples.
            token (str | None): HuggingFace authentication token for accessing the gated dataset.
                If None, uses the default token from the environment or HuggingFace CLI login.

        Raises:
            ValueError: If any of the specified categories are invalid.
        """
        # Validate first so an invalid loader is never left half-constructed.
        # Plain strings are tolerated alongside VLGuardCategory members since the
        # comparison is against the enum values.
        if categories is not None:
            valid_categories = {cat.value for cat in VLGuardCategory}
            invalid_categories = {
                cat.value if isinstance(cat, VLGuardCategory) else cat for cat in categories
            } - valid_categories
            if invalid_categories:
                raise ValueError(f"Invalid VLGuard categories: {', '.join(invalid_categories)}")

        self.subset = subset
        self.categories = categories
        self.max_examples = max_examples
        self.token = token
        self.source = f"https://huggingface.co/datasets/{_HF_REPO_ID}"

    @property
    def dataset_name(self) -> str:
        """Return the dataset name."""
        return "vlguard"

    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
        """
        Fetch VLGuard multimodal examples and return as SeedDataset.

        Downloads the test split metadata and images from HuggingFace, then creates
        multimodal prompts (text + image pairs linked by prompt_group_id) based on
        the selected subset.

        Args:
            cache (bool): Whether to cache downloaded files. Defaults to True.

        Returns:
            SeedDataset: A SeedDataset containing the multimodal examples.
        """
        logger.info(f"Loading VLGuard dataset (subset={self.subset.value})")

        metadata, image_dir = await self._download_dataset_files_async(cache=cache)

        # The category filter set is loop-invariant; build it once up front.
        category_values = {cat.value for cat in self.categories} if self.categories is not None else None

        prompts: list[SeedPrompt] = []

        for example in metadata:
            image_filename = example.get("image")
            is_safe = example.get("safe")
            category = example.get("harmful_category", "")
            subcategory = example.get("harmful_subcategory", "")
            instr_resp_raw = example.get("instr-resp")
            if not instr_resp_raw or not isinstance(instr_resp_raw, list):
                continue
            instr_resp: list[dict[str, str]] = instr_resp_raw

            if not image_filename:
                continue

            # Filter by subset (safe flag).
            if self.subset == VLGuardSubset.UNSAFES and is_safe:
                continue
            if self.subset in (VLGuardSubset.SAFE_UNSAFES, VLGuardSubset.SAFE_SAFES) and not is_safe:
                continue

            # Filter by categories (None means "include all").
            if category_values is not None and category not in category_values:
                continue

            instruction = self._extract_instruction(instr_resp)
            if not instruction:
                continue

            image_path = image_dir / image_filename
            if not image_path.exists():
                logger.warning(f"Image not found: {image_path}")
                continue

            # Shared group id links the text and image halves of one example.
            group_id = uuid.uuid4()

            text_prompt = SeedPrompt(
                value=instruction,
                data_type="text",
                name="VLGuard Text",
                dataset_name=self.dataset_name,
                harm_categories=[category],
                description=f"Text component of VLGuard multimodal prompt ({self.subset.value}).",
                source=self.source,
                prompt_group_id=group_id,
                sequence=0,
                metadata={
                    "category": category,
                    "subcategory": subcategory,
                    "subset": self.subset.value,
                    "safe_image": is_safe,
                },
            )

            image_prompt = SeedPrompt(
                value=str(image_path),
                data_type="image_path",
                name="VLGuard Image",
                dataset_name=self.dataset_name,
                harm_categories=[category],
                description=f"Image component of VLGuard multimodal prompt ({self.subset.value}).",
                source=self.source,
                prompt_group_id=group_id,
                sequence=1,
                metadata={
                    "category": category,
                    "subcategory": subcategory,
                    "subset": self.subset.value,
                    "safe_image": is_safe,
                    "original_filename": image_filename,
                },
            )

            prompts.append(text_prompt)
            prompts.append(image_prompt)

            # Each example contributes one text and one image prompt, so stop
            # once len(prompts) reaches max_examples * 2.
            if self.max_examples is not None and len(prompts) >= self.max_examples * 2:
                break

        logger.info(f"Successfully loaded {len(prompts)} prompts from VLGuard dataset ({self.subset.value})")

        return SeedDataset(seeds=prompts, dataset_name=self.dataset_name)

    def _extract_instruction(self, instr_resp: list[dict[str, str]]) -> str | None:
        """
        Extract the instruction text from an example based on the current subset.

        Args:
            instr_resp (list[dict[str, str]]): List of instruction-response dictionaries from VLGuard.

        Returns:
            str | None: The instruction text, or None if not found for the given subset.
        """
        if self.subset == VLGuardSubset.UNSAFES:
            # Unsafe-image examples carry a single plain "instruction" entry.
            if instr_resp and "instruction" in instr_resp[0]:
                return str(instr_resp[0]["instruction"])
        elif self.subset == VLGuardSubset.SAFE_UNSAFES:
            for item in instr_resp:
                if "unsafe_instruction" in item:
                    return str(item["unsafe_instruction"])
        elif self.subset == VLGuardSubset.SAFE_SAFES:
            for item in instr_resp:
                if "safe_instruction" in item:
                    return str(item["safe_instruction"])
        return None

    async def _download_dataset_files_async(self, *, cache: bool = True) -> tuple[list[dict], Path]:
        """
        Download VLGuard metadata and images from HuggingFace.

        Args:
            cache (bool): Whether to use cached files if available.

        Returns:
            tuple[list[dict], Path]: Tuple of (metadata list, image directory path).
                Metadata entries mix value types (str, bool, list), hence the
                unparameterized dict annotation.
        """
        from huggingface_hub import hf_hub_download

        cache_dir = DB_DATA_PATH / "seed-prompt-entries" / "vlguard"
        cache_dir.mkdir(parents=True, exist_ok=True)

        json_path = cache_dir / "test.json"
        image_dir = cache_dir / "test"

        # Use cache only if both the metadata file and at least one image exist.
        if cache and json_path.exists() and image_dir.exists() and any(image_dir.iterdir()):
            logger.info("Using cached VLGuard dataset")
            with open(json_path, encoding="utf-8") as f:
                metadata = json.load(f)
            return metadata, image_dir

        logger.info("Downloading VLGuard dataset from HuggingFace...")

        def _download_sync() -> tuple[str, str]:
            # hf_hub_download is blocking; this helper runs in a worker thread.
            json_file = hf_hub_download(
                repo_id=_HF_REPO_ID,
                filename="test.json",
                repo_type="dataset",
                local_dir=str(cache_dir),
                token=self.token,
            )
            zip_file = hf_hub_download(
                repo_id=_HF_REPO_ID,
                filename="test.zip",
                repo_type="dataset",
                local_dir=str(cache_dir),
                token=self.token,
            )
            return json_file, zip_file

        await asyncio.to_thread(_download_sync)

        # Extract images from zip (archive members live under "test/").
        zip_path = cache_dir / "test.zip"
        if zip_path.exists():
            logger.info("Extracting VLGuard test images...")
            with zipfile.ZipFile(str(zip_path), "r") as zf:
                zf.extractall(str(cache_dir))

        with open(json_path, encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata, image_dir

import json
import zipfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from pyrit.datasets.seed_datasets.remote.vlguard_dataset import (
    VLGuardCategory,
    VLGuardSubset,
    _VLGuardDataset,
)
from pyrit.models import SeedDataset, SeedPrompt


@pytest.fixture
def mock_vlguard_metadata():
    """Mock VLGuard test.json metadata."""
    # Three examples: two unsafe images (risky behavior / privacy) and one safe
    # image carrying both a safe and an unsafe instruction, so every subset
    # (unsafes, safe_unsafes, safe_safes) has at least one matching entry.
    return [
        {
            "id": "test_001",
            "image": "unsafe_001.jpg",
            "safe": False,
            "harmful_category": "risky behavior",
            "harmful_subcategory": "violence",
            "instr-resp": [
                {
                    "instruction": "Describe what is happening in this image.",
                    "response": "I cannot describe this violent content.",
                }
            ],
        },
        {
            "id": "test_002",
            "image": "safe_001.jpg",
            "safe": True,
            "harmful_category": "deception",
            "harmful_subcategory": "disinformation",
            "instr-resp": [
                {
                    "safe_instruction": "Describe the scene in this image.",
                    "response": "The image shows a peaceful park.",
                },
                {
                    "unsafe_instruction": "How can I use this image to spread misinformation?",
                    "response": "I cannot help with spreading misinformation.",
                },
            ],
        },
        {
            "id": "test_003",
            "image": "unsafe_002.jpg",
            "safe": False,
            "harmful_category": "privacy",
            "harmful_subcategory": "personal data",
            "instr-resp": [
                {
                    "instruction": "Read the personal information shown in this image.",
                    "response": "I cannot extract personal information from images.",
                }
            ],
        },
    ]


class TestVLGuardDataset:
    """Test the VLGuard dataset loader."""

    def test_dataset_name(self):
        """Test dataset_name property."""
        loader = _VLGuardDataset()
        assert loader.dataset_name == "vlguard"

    def test_default_subset_is_unsafes(self):
        """Test default subset is UNSAFES."""
        loader = _VLGuardDataset()
        assert loader.subset == VLGuardSubset.UNSAFES

    def test_invalid_category_raises(self):
        """Test that invalid categories raise ValueError."""
        # Create a mock that looks like an enum with an invalid value
        invalid_cat = MagicMock(spec=VLGuardCategory)
        invalid_cat.value = "InvalidCategory"
        # MagicMock(spec=VLGuardCategory) passes the loader's isinstance check,
        # so its .value is compared against the valid category values and rejected.
        with pytest.raises(ValueError, match="Invalid VLGuard categories"):
            _VLGuardDataset(categories=[invalid_cat])

    def test_valid_categories_accepted(self):
        """Test that valid categories are accepted."""
        loader = _VLGuardDataset(categories=[VLGuardCategory.PRIVACY, VLGuardCategory.DECEPTION])
        assert len(loader.categories) == 2

    @pytest.mark.asyncio
    async def test_fetch_unsafes_subset(self, mock_vlguard_metadata, tmp_path):
        """Test fetching the unsafes subset returns only unsafe image examples."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "unsafe_001.jpg").write_bytes(b"fake image")
        (image_dir / "unsafe_002.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        assert isinstance(dataset, SeedDataset)
        # 2 unsafe examples × 2 prompts each = 4 prompts
        assert len(dataset.seeds) == 4
        assert all(isinstance(p, SeedPrompt) for p in dataset.seeds)

        text_prompts = [p for p in dataset.seeds if p.data_type == "text"]
        assert len(text_prompts) == 2
        assert text_prompts[0].value == "Describe what is happening in this image."
        assert text_prompts[0].metadata["subset"] == "unsafes"
        assert text_prompts[0].metadata["safe_image"] is False

    @pytest.mark.asyncio
    async def test_fetch_safe_unsafes_subset(self, mock_vlguard_metadata, tmp_path):
        """Test fetching the safe_unsafes subset returns safe images with unsafe instructions."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "safe_001.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(subset=VLGuardSubset.SAFE_UNSAFES)

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        assert len(dataset.seeds) == 2  # 1 example × 2 prompts
        text_prompts = [p for p in dataset.seeds if p.data_type == "text"]
        assert text_prompts[0].value == "How can I use this image to spread misinformation?"
        assert text_prompts[0].metadata["safe_image"] is True

    @pytest.mark.asyncio
    async def test_fetch_safe_safes_subset(self, mock_vlguard_metadata, tmp_path):
        """Test fetching the safe_safes subset returns safe images with safe instructions."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "safe_001.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(subset=VLGuardSubset.SAFE_SAFES)

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        assert len(dataset.seeds) == 2  # 1 example × 2 prompts
        text_prompts = [p for p in dataset.seeds if p.data_type == "text"]
        assert text_prompts[0].value == "Describe the scene in this image."

    @pytest.mark.asyncio
    async def test_category_filtering(self, mock_vlguard_metadata, tmp_path):
        """Test that category filtering returns only matching examples."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "unsafe_002.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(
            subset=VLGuardSubset.UNSAFES,
            categories=[VLGuardCategory.PRIVACY],
        )

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        assert len(dataset.seeds) == 2  # Only the Privacy example
        text_prompts = [p for p in dataset.seeds if p.data_type == "text"]
        assert text_prompts[0].harm_categories == ["privacy"]

    @pytest.mark.asyncio
    async def test_max_examples(self, mock_vlguard_metadata, tmp_path):
        """Test that max_examples limits the number of returned examples."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "unsafe_001.jpg").write_bytes(b"fake image")
        (image_dir / "unsafe_002.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES, max_examples=1)

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        # max_examples=1 → 1 example × 2 prompts = 2 prompts
        assert len(dataset.seeds) == 2

    @pytest.mark.asyncio
    async def test_prompt_group_id_links_text_and_image(self, mock_vlguard_metadata, tmp_path):
        """Test that text and image prompts share the same prompt_group_id."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "unsafe_001.jpg").write_bytes(b"fake image")
        (image_dir / "unsafe_002.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        # Each pair should share a group_id
        text_prompt = dataset.seeds[0]
        image_prompt = dataset.seeds[1]
        assert text_prompt.prompt_group_id == image_prompt.prompt_group_id
        assert text_prompt.data_type == "text"
        assert image_prompt.data_type == "image_path"
        assert text_prompt.sequence == 0
        assert image_prompt.sequence == 1

    @pytest.mark.asyncio
    async def test_missing_image_skipped(self, mock_vlguard_metadata, tmp_path):
        """Test that examples with missing images are skipped."""
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        # Only create one of the two unsafe images
        (image_dir / "unsafe_001.jpg").write_bytes(b"fake image")

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)

        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
        ):
            dataset = await loader.fetch_dataset()

        # Only 1 example should be included (the one with the existing image)
        assert len(dataset.seeds) == 2

    # NOTE(review): the _extract_instruction tests below contain no awaits; the
    # asyncio markers are unnecessary (harmless, but could be plain sync tests).
    @pytest.mark.asyncio
    async def test_extract_instruction_unsafes(self):
        """Test _extract_instruction for unsafes subset."""
        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)
        instr_resp = [{"instruction": "Test instruction", "response": "Test response"}]
        assert loader._extract_instruction(instr_resp) == "Test instruction"

    @pytest.mark.asyncio
    async def test_extract_instruction_safe_unsafes(self):
        """Test _extract_instruction for safe_unsafes subset."""
        loader = _VLGuardDataset(subset=VLGuardSubset.SAFE_UNSAFES)
        instr_resp = [
            {"safe_instruction": "Safe question", "response": "Safe answer"},
            {"unsafe_instruction": "Unsafe question", "response": "Refusal"},
        ]
        assert loader._extract_instruction(instr_resp) == "Unsafe question"

    @pytest.mark.asyncio
    async def test_extract_instruction_returns_none_for_missing_key(self):
        """Test _extract_instruction returns None when key is missing."""
        loader = _VLGuardDataset(subset=VLGuardSubset.SAFE_UNSAFES)
        instr_resp = [{"safe_instruction": "Safe question", "response": "Safe answer"}]
        assert loader._extract_instruction(instr_resp) is None

    @pytest.mark.asyncio
    async def test_extract_instruction_safe_safes(self):
        """Test _extract_instruction for safe_safes subset."""
        loader = _VLGuardDataset(subset=VLGuardSubset.SAFE_SAFES)
        instr_resp = [
            {"safe_instruction": "Describe the park", "response": "A peaceful park."},
        ]
        assert loader._extract_instruction(instr_resp) == "Describe the park"

    @pytest.mark.asyncio
    async def test_examples_with_invalid_instr_resp_skipped(self, tmp_path):
        """Test that examples with missing or non-list instr-resp are skipped."""
        metadata = [
            {"image": "img1.jpg", "safe": False, "harmful_category": "privacy", "harmful_subcategory": "personal data"},
            {
                "image": "img2.jpg",
                "safe": False,
                "harmful_category": "privacy",
                "harmful_subcategory": "personal data",
                "instr-resp": "not a list",
            },
        ]
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "img1.jpg").write_bytes(b"fake")
        (image_dir / "img2.jpg").write_bytes(b"fake")

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)
        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(metadata, image_dir)),
        ):
            # All examples filtered out → SeedDataset rejects the empty prompt list.
            with pytest.raises(ValueError, match="SeedDataset cannot be empty"):
                await loader.fetch_dataset()

    @pytest.mark.asyncio
    async def test_examples_with_missing_image_field_skipped(self, tmp_path):
        """Test that examples with no image field are skipped."""
        metadata = [
            {
                "safe": False,
                "harmful_category": "privacy",
                "harmful_subcategory": "personal data",
                "instr-resp": [{"instruction": "Describe this.", "response": "No."}],
            },
        ]
        image_dir = tmp_path / "test"
        image_dir.mkdir()

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)
        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(metadata, image_dir)),
        ):
            with pytest.raises(ValueError, match="SeedDataset cannot be empty"):
                await loader.fetch_dataset()

    @pytest.mark.asyncio
    async def test_examples_with_no_extractable_instruction_skipped(self, tmp_path):
        """Test that examples where _extract_instruction returns None are skipped."""
        metadata = [
            {
                "image": "img.jpg",
                "safe": False,
                "harmful_category": "privacy",
                "harmful_subcategory": "personal data",
                "instr-resp": [{"response": "No instruction key here."}],
            },
        ]
        image_dir = tmp_path / "test"
        image_dir.mkdir()
        (image_dir / "img.jpg").write_bytes(b"fake")

        loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES)
        with patch.object(
            loader,
            "_download_dataset_files_async",
            new=AsyncMock(return_value=(metadata, image_dir)),
        ):
            with pytest.raises(ValueError, match="SeedDataset cannot be empty"):
                await loader.fetch_dataset()

    @pytest.mark.asyncio
    async def test_download_dataset_files_uses_cache(self, tmp_path):
        """Test that _download_dataset_files_async returns cached data when available."""
        # Pre-populate the cache layout the loader expects under DB_DATA_PATH.
        cache_dir = tmp_path / "seed-prompt-entries" / "vlguard"
        cache_dir.mkdir(parents=True)

        json_path = cache_dir / "test.json"
        image_dir = cache_dir / "test"
        image_dir.mkdir()
        (image_dir / "img.jpg").write_bytes(b"fake")

        test_metadata = [{"image": "img.jpg", "safe": False}]
        json_path.write_text(json.dumps(test_metadata), encoding="utf-8")

        loader = _VLGuardDataset()

        with patch("pyrit.datasets.seed_datasets.remote.vlguard_dataset.DB_DATA_PATH", tmp_path):
            metadata, result_dir = await loader._download_dataset_files_async(cache=True)

        assert metadata == test_metadata
        assert result_dir == image_dir

    @pytest.mark.asyncio
    async def test_download_dataset_files_downloads_when_no_cache(self, tmp_path):
        """Test that _download_dataset_files_async downloads and extracts when cache is empty."""
        cache_dir = tmp_path / "seed-prompt-entries" / "vlguard"

        test_metadata = [{"image": "test/img.jpg", "safe": False}]

        # Stand-in for hf_hub_download: writes the metadata file or a zip whose
        # member path ("test/img.jpg") matches the real dataset layout.
        def mock_hf_download(*, repo_id, filename, repo_type, local_dir, token):
            local = Path(local_dir)
            local.mkdir(parents=True, exist_ok=True)
            if filename == "test.json":
                path = local / "test.json"
                path.write_text(json.dumps(test_metadata), encoding="utf-8")
                return str(path)
            if filename == "test.zip":
                zip_path = local / "test.zip"
                img_dir = local / "test"
                img_dir.mkdir(exist_ok=True)
                (img_dir / "img.jpg").write_bytes(b"fake image")
                with zipfile.ZipFile(str(zip_path), "w") as zf:
                    zf.write(str(img_dir / "img.jpg"), "test/img.jpg")
                return str(zip_path)
            return None

        loader = _VLGuardDataset(token="fake_token")

        with (
            patch("pyrit.datasets.seed_datasets.remote.vlguard_dataset.DB_DATA_PATH", tmp_path),
            patch("huggingface_hub.hf_hub_download", side_effect=mock_hf_download),
        ):
            metadata, result_dir = await loader._download_dataset_files_async(cache=False)

        assert metadata == test_metadata
        assert result_dir == cache_dir / "test"