From 7a998ebc4f2e82b1dd01650217c3796031db0eb5 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Wed, 22 Apr 2026 16:12:51 -0400 Subject: [PATCH 1/8] Memory: add has_converters filter; flip converter_classes=[] to "no filter" Add a tri-state has_converters param to get_attack_results. Previously converter_classes=[] meant "attacks with no converters"; it now means "no filter". Use has_converters=False for the old behavior. BREAKING CHANGE: converter_classes=[] no longer filters to zero-converter attacks. No in-repo callers relied on this. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/azure_sql_memory.py | 53 +++--- pyrit/memory/memory_interface.py | 73 +++++--- pyrit/memory/sqlite_memory.py | 21 ++- .../test_interface_attack_results.py | 168 +++++++++++++++--- 4 files changed, 236 insertions(+), 79 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index d93d2bd4a..2a0377f52 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -7,14 +7,14 @@ from collections.abc import MutableSequence, Sequence from contextlib import closing, suppress from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker from sqlalchemy.orm.session import Session -from sqlalchemy.sql.expression import TextClause +from sqlalchemy.sql.expression import ColumnElement, TextClause from pyrit.auth.azure_auth import AzureAuth from pyrit.common import default_values @@ -454,35 +454,40 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories ) ) - def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: + def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels. + Azure SQL implementation for filtering AttackResults by labels. - Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. - - Args: - labels (dict[str, str]): Dictionary of label key-value pairs to filter by. + Uses JSON_VALUE() with parameterized IN clauses. See + ``MemoryInterface._get_attack_result_label_condition`` for semantics. Returns: Any: SQLAlchemy exists subquery condition with bound parameters. """ - # Build JSON conditions for all labels with parameterized queries - label_conditions = [] - bindparams_dict = {} - for key, value in labels.items(): - param_name = f"label_{key}" - label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") - bindparams_dict[param_name] = str(value) - - combined_conditions = " AND ".join(label_conditions) - - return exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), + label_conditions: list[str] = [] + bindparams_dict: dict[str, str] = {} + for key, raw_value in labels.items(): + values = [raw_value] if isinstance(raw_value, str) else list(raw_value) + if not values: + continue + placeholders = [] + for idx, v in enumerate(values): + param_name = f"label_{key}_{idx}" + placeholders.append(f":{param_name}") + bindparams_dict[param_name] = str(v) + label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(placeholders)})") + + base = [ + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + ] + if label_conditions: + combined = " AND ".join(label_conditions) + base.append( + cast("ColumnElement[bool]", text(f"ISJSON(labels) = 1 AND {combined}").bindparams(**bindparams_dict)) ) - ) + + return exists().where(and_(*base)) def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index d020a4a0a..eb4d31bd4 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from sqlalchemy import MetaData, and_ +from sqlalchemy import MetaData, and_, not_, or_ from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.attributes import InstrumentedAttribute @@ -417,13 +417,18 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories """ @abc.abstractmethod - def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: + def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels in the associated PromptMemoryEntry records. + Semantics: entries are AND-combined across label names; within a single + entry, a string value is an equality match and a sequence value is an + OR match over the listed values. An empty sequence is a no-op for that + label. See ``get_attack_results`` for examples. + Args: - labels: Dictionary of labels that must ALL be present. + labels: Label-name-to-value(s) map. Returns: Database-specific SQLAlchemy condition. @@ -1453,10 +1458,11 @@ def get_attack_results( objective: Optional[str] = None, objective_sha256: Optional[Sequence[str]] = None, outcome: Optional[str] = None, - attack_class: Optional[str] = None, + attack_classes: Optional[Sequence[str]] = None, converter_classes: Optional[Sequence[str]] = None, + has_converters: Optional[bool] = None, targeted_harm_categories: Optional[Sequence[str]] = None, - labels: Optional[dict[str, str]] = None, + labels: Optional[dict[str, str | Sequence[str]]] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[AttackResult]: """ @@ -1470,11 +1476,17 @@ def get_attack_results( Defaults to None. outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). Defaults to None. - attack_class (Optional[str], optional): Filter by exact attack class_name in attack_identifier. - Defaults to None. + attack_classes (Optional[Sequence[str]], optional): Filter by exact attack class_name in + attack_identifier. Returns attacks matching ANY of the listed class names (OR logic, + case-sensitive). An empty sequence applies no filter. Defaults to None. converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. Returns only attacks that used ALL specified converters (AND logic, case-insensitive). + An empty sequence applies no filter. To filter by presence/absence of any converter, + use the ``has_converters`` parameter instead. Defaults to None. + has_converters (Optional[bool], optional): Filter by converter presence. + ``True`` returns only attacks that used at least one converter. ``False`` returns + only attacks that used no converters. ``None`` applies no filter. Defaults to None. targeted_harm_categories (Optional[Sequence[str]], optional): A list of targeted harm categories to filter results by. These targeted harm categories are associated with the prompts themselves, @@ -1482,9 +1494,14 @@ def get_attack_results( not necessarily one(s) that were found in the response. By providing a list, this means ALL categories in the list must be present. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. - These labels are associated with the prompts themselves, used for custom tagging and tracking. - Defaults to None. + labels (Optional[dict[str, str | Sequence[str]]], optional): Filter results + by attack labels. Entries are AND-combined across label names; within a + single entry, a string value is an equality match and a sequence value is + an OR match over the listed values. An empty sequence applies no filter + for that label. Example: ``{"operator": "roakey", "operation": + ["roakey_op_a", "roakey_op_b"]}`` matches attacks where ``operator == + "roakey"`` AND (``operation == "roakey_op_a"`` OR ``operation == + "roakey_op_b"``). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. @@ -1509,20 +1526,25 @@ def get_attack_results( if outcome: conditions.append(AttackResultEntry.outcome == outcome) - if attack_class: - # Use database-specific JSON query method + if attack_classes: + # Use database-specific JSON query method; OR-match across the provided class names. conditions.append( - self._get_condition_json_property_match( - json_column=AttackResultEntry.atomic_attack_identifier, - property_path="$.children.attack_technique.children.attack.class_name", - value=attack_class, - case_sensitive=True, + or_( + *[ + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path="$.children.attack_technique.children.attack.class_name", + value=ac, + case_sensitive=True, + ) + for ac in attack_classes + ] ) ) - if converter_classes is not None: - # converter_classes=[] means "only attacks with no converters" - # converter_classes=["A","B"] means "must have all listed converters" + if converter_classes: + # Non-empty sequence filters by converter class names (ALL must be present). + # Empty sequence or None applies no filter. conditions.append( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, @@ -1532,6 +1554,17 @@ def get_attack_results( ) ) + if has_converters is not None: + # Reuse the array-empty match (array_to_match=[]) as the "no converters" condition; + # invert it for "has at least one converter". + empty_condition = self._get_condition_json_array_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path="$.children.attack_technique.children.attack.children.request_converters", + array_element_path="$.class_name", + array_to_match=[], + ) + conditions.append(not_(empty_condition) if has_converters else empty_condition) + if targeted_harm_categories: warnings.warn( "The 'targeted_harm_categories' parameter is deprecated and will be removed in a future release.", diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index a4039c1b7..9ba32e917 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -620,11 +620,15 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories ) return targeted_harm_categories_subquery # noqa: RET504 - def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: + def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ SQLite implementation for filtering AttackResults by labels. Uses json_extract() function specific to SQLite. + Keys are AND-combined. For each key, a string value is an equality match; + a sequence value is an OR-within-key match (any listed value matches). + Empty sequences are no-ops (no constraint on that key). + Returns: Any: A SQLAlchemy subquery for filtering by labels. """ @@ -632,16 +636,21 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - labels_subquery = exists().where( + per_key_conditions = [] + for key, raw_value in labels.items(): + values = [raw_value] if isinstance(raw_value, str) else list(raw_value) + if not values: + continue + col = func.json_extract(PromptMemoryEntry.labels, f"$.{key}") + per_key_conditions.append(col.in_(values)) + + return exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), - and_( - *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()] - ), + and_(*per_key_conditions), ) ) - return labels_subquery # noqa: RET504 def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 6a3857c57..dadd67767 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -839,6 +839,50 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) assert conversation_ids == {"conv_1", "conv_2"} +def test_get_attack_results_by_labels_or_within_key(sqlite_instance: MemoryInterface): + """Test that a sequence value for a label key matches any of the values (OR-within-key).""" + + message_pieces = [ + create_message_piece("conv_1", 1, labels={"operator": "alice"}), + create_message_piece("conv_2", 2, labels={"operator": "bob"}), + create_message_piece("conv_3", 3, labels={"operator": "charlie"}), + ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=message_pieces) + + sqlite_instance.add_attack_results_to_memory( + attack_results=[ + create_attack_result("conv_1", 1), + create_attack_result("conv_2", 2), + create_attack_result("conv_3", 3), + ] + ) + + results = sqlite_instance.get_attack_results(labels={"operator": ["alice", "bob"]}) + assert {r.conversation_id for r in results} == {"conv_1", "conv_2"} + + +def test_get_attack_results_by_labels_or_within_key_and_across_keys(sqlite_instance: MemoryInterface): + """Test that OR-within-key composes with AND-across-keys.""" + + message_pieces = [ + # matches: operator in {alice, bob} AND operation == red + create_message_piece("conv_1", 1, labels={"operator": "alice", "operation": "red"}), + create_message_piece("conv_2", 2, labels={"operator": "bob", "operation": "red"}), + # fails operator constraint + create_message_piece("conv_3", 3, labels={"operator": "charlie", "operation": "red"}), + # fails operation constraint + create_message_piece("conv_4", 4, labels={"operator": "alice", "operation": "blue"}), + ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=message_pieces) + + sqlite_instance.add_attack_results_to_memory( + attack_results=[create_attack_result(f"conv_{i}", i) for i in range(1, 5)] + ) + + results = sqlite_instance.get_attack_results(labels={"operator": ["alice", "bob"], "operation": ["red"]}) + assert {r.conversation_id for r in results} == {"conv_1", "conv_2"} + + def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): """Test filtering attack results by both harm categories and labels.""" @@ -1155,47 +1199,68 @@ def _make_attack_result_with_identifier( ) -def test_get_attack_results_by_attack_class(sqlite_instance: MemoryInterface): - """Test filtering attack results by attack_class matches class_name in JSON.""" +def test_get_attack_results_by_attack_classes(sqlite_instance: MemoryInterface): + """Test filtering attack results by attack_classes matches class_name in JSON.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + results = sqlite_instance.get_attack_results(attack_classes=["CrescendoAttack"]) assert len(results) == 2 assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} -def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInterface): - """Test that attack_class filter returns empty when nothing matches.""" +def test_get_attack_results_by_attack_classes_no_match(sqlite_instance: MemoryInterface): + """Test that attack_classes filter returns empty when nothing matches.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(attack_class="NonExistentAttack") + results = sqlite_instance.get_attack_results(attack_classes=["NonExistentAttack"]) assert len(results) == 0 -def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): - """Test that attack_class filter is case-sensitive (exact match).""" +def test_get_attack_results_by_attack_classes_case_sensitive(sqlite_instance: MemoryInterface): + """Test that attack_classes filter is case-sensitive (exact match).""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(attack_class="crescendoattack") + results = sqlite_instance.get_attack_results(attack_classes=["crescendoattack"]) assert len(results) == 0 -def test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): - """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" +def test_get_attack_results_by_attack_classes_no_identifier(sqlite_instance: MemoryInterface): + """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_classes filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + results = sqlite_instance.get_attack_results(attack_classes=["CrescendoAttack"]) assert len(results) == 1 assert results[0].conversation_id == "conv_2" +def test_get_attack_results_by_attack_classes_multi(sqlite_instance: MemoryInterface): + """Test that multiple attack_classes use OR logic — matches any of the listed class names.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "TreeOfAttacksAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + results = sqlite_instance.get_attack_results(attack_classes=["CrescendoAttack", "ManualAttack"]) + assert {r.conversation_id for r in results} == {"conv_1", "conv_2"} + + +def test_get_attack_results_attack_classes_empty_returns_all(sqlite_instance: MemoryInterface): + """Test that attack_classes=[] behaves like None (no filter applied).""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + results = sqlite_instance.get_attack_results(attack_classes=[]) + assert len(results) == 2 + + def test_get_attack_results_converter_classes_none_returns_all(sqlite_instance: MemoryInterface): """Test that converter_classes=None (omitted) returns all attacks unfiltered.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) @@ -1207,8 +1272,8 @@ def test_get_attack_results_converter_classes_none_returns_all(sqlite_instance: assert len(results) == 3 -def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite_instance: MemoryInterface): - """Test that converter_classes=[] returns only attacks with no converters.""" +def test_get_attack_results_converter_classes_empty_returns_all(sqlite_instance: MemoryInterface): + """Test that converter_classes=[] behaves like None (no filter applied).""" ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") # converter_ids=None ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) # converter_ids=[] @@ -1218,12 +1283,7 @@ def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite ) results = sqlite_instance.get_attack_results(converter_classes=[]) - conv_ids = {r.conversation_id for r in results} - # Should include attacks with no converters (None key, empty array, or empty identifier) - assert "conv_1" not in conv_ids, "Should not include attacks that have converters" - assert "conv_2" in conv_ids, "Should include attacks where converter key is absent (None)" - assert "conv_3" in conv_ids, "Should include attacks with empty converter list" - assert "conv_4" in conv_ids, "Should include attacks with empty attack_identifier" + assert len(results) == 4 def test_get_attack_results_converter_classes_single_match(sqlite_instance: MemoryInterface): @@ -1271,29 +1331,79 @@ def test_get_attack_results_converter_classes_no_match(sqlite_instance: MemoryIn assert len(results) == 0 -def test_get_attack_results_attack_class_and_converter_classes_combined(sqlite_instance: MemoryInterface): - """Test combining attack_type and converter_types filters.""" +def test_get_attack_results_attack_classes_and_converter_classes_combined(sqlite_instance: MemoryInterface): + """Test combining attack_classes and converter_classes filters.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack", ["Base64Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack", ["ROT13Converter"]) ar4 = _make_attack_result_with_identifier("conv_4", "CrescendoAttack") # No converters sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=["Base64Converter"]) + results = sqlite_instance.get_attack_results( + attack_classes=["CrescendoAttack"], converter_classes=["Base64Converter"] + ) assert len(results) == 1 assert results[0].conversation_id == "conv_1" -def test_get_attack_results_attack_class_with_no_converters(sqlite_instance: MemoryInterface): - """Test combining attack_type with converter_types=[] (no converters).""" +def test_get_attack_results_attack_classes_converter_classes_empty_no_filter(sqlite_instance: MemoryInterface): + """Combining attack_classes with converter_classes=[] applies no converter filter.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") # No converters - ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # No converters + ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # Different class sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=[]) - assert len(results) == 1 - assert results[0].conversation_id == "conv_2" + results = sqlite_instance.get_attack_results(attack_classes=["CrescendoAttack"], converter_classes=[]) + assert {r.conversation_id for r in results} == {"conv_1", "conv_2"} + + +def test_get_attack_results_has_converters_true(sqlite_instance: MemoryInterface): + """has_converters=True returns only attacks with at least one converter.""" + ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") # converter_ids=None + ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) # converter_ids=[] + ar_no_identifier = create_attack_result("conv_4", 4) # No identifier → stored as {} + sqlite_instance.add_attack_results_to_memory( + attack_results=[ar_with_conv, ar_no_conv_none, ar_no_conv_empty, ar_no_identifier] + ) + + results = sqlite_instance.get_attack_results(has_converters=True) + assert {r.conversation_id for r in results} == {"conv_1"} + + +def test_get_attack_results_has_converters_false(sqlite_instance: MemoryInterface): + """has_converters=False returns only attacks with zero converters.""" + ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") + ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) + ar_no_identifier = create_attack_result("conv_4", 4) + sqlite_instance.add_attack_results_to_memory( + attack_results=[ar_with_conv, ar_no_conv_none, ar_no_conv_empty, ar_no_identifier] + ) + + results = sqlite_instance.get_attack_results(has_converters=False) + assert {r.conversation_id for r in results} == {"conv_2", "conv_3", "conv_4"} + + +def test_get_attack_results_has_converters_none_returns_all(sqlite_instance: MemoryInterface): + """has_converters=None applies no filter.""" + ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar_no_conv = _make_attack_result_with_identifier("conv_2", "Attack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar_with_conv, ar_no_conv]) + + results = sqlite_instance.get_attack_results(has_converters=None) + assert len(results) == 2 + + +def test_get_attack_results_has_converters_false_combined_with_attack_classes(sqlite_instance: MemoryInterface): + """has_converters=False composes (AND) with attack_classes.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") # No converters + ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # No converters, different class + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + results = sqlite_instance.get_attack_results(attack_classes=["CrescendoAttack"], has_converters=False) + assert {r.conversation_id for r in results} == {"conv_2"} # ============================================================================ From 681c3dfa70bb6bea7eef5a5acb83b3100c191cbb Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Wed, 22 Apr 2026 16:13:06 -0400 Subject: [PATCH 2/8] Backend: rename attack_class->attack_types; forward has_converters Update the AttackService and REST route to forward the new has_converters param to memory, and rename the route query param from attack_class (scalar) to attack_types (repeated) for consistency with the other list-valued filters. BREAKING CHANGE: REST query param attack_class renamed to attack_types. Only in-tree consumer is the frontend, updated in the next commit. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/routes/attacks.py | 49 ++++++-- pyrit/backend/services/attack_service.py | 49 ++++++-- tests/unit/backend/test_api_routes.py | 117 +++++++++++++++++- tests/unit/backend/test_attack_service.py | 142 ++++++++++++++++++++-- 4 files changed, 322 insertions(+), 35 deletions(-) diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 3fd8845dc..ed06360a3 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -9,6 +9,7 @@ """ import logging +from collections.abc import Sequence from typing import Literal, Optional from fastapi import APIRouter, HTTPException, Query, status @@ -38,21 +39,29 @@ router = APIRouter(prefix="/attacks", tags=["attacks"]) -def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str]]: +def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str | Sequence[str]]]: """ - Parse label query params in 'key:value' format to a dict. + Parse 'key:value' label query params into a dict grouping values by key. + + Repeating the same key produces OR-within-key semantics downstream + (e.g. ?label=operator:alice&label=operator:bob matches either operator). + Different keys are combined with AND. Returns: - Dict mapping label keys to values, or None if no valid labels. + Dict mapping each label key to a list of values, or None if no valid labels. """ if not label_params: return None - labels = {} + labels: dict[str, list[str]] = {} for param in label_params: if ":" in param: key, value = param.split(":", 1) - labels[key.strip()] = value.strip() - return labels if labels else None + labels.setdefault(key.strip(), []).append(value.strip()) + if not labels: + return None + # Widen value type to match the service signature (dict values are invariant). + widened: dict[str, str | Sequence[str]] = dict(labels) + return widened @router.get( @@ -60,12 +69,27 @@ def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str]] response_model=AttackListResponse, ) async def list_attacks( - attack_type: Optional[str] = Query(None, description="Filter by exact attack type name"), + attack_types: Optional[list[str]] = Query( + None, + description="Filter by attack type names (repeatable). OR-matched, case-sensitive. " + "Omit to return all attacks regardless of type.", + ), converter_types: Optional[list[str]] = Query( None, - description="Filter by converter type names (repeatable, AND logic). " + description="Filter by converter type names (repeatable). " + "Combined per converter_types_match. " "Omit to return all attacks regardless of converters. " - "Pass with no values to match only no-converter attacks.", + "To filter for attacks with no converters, use has_converters=false.", + ), + converter_types_match: Literal["any", "all"] = Query( + "all", + description="How to combine multiple converter_types: 'any' (attack has at least one) " + "or 'all' (attack has every one). Defaults to 'all'.", + ), + has_converters: Optional[bool] = Query( + None, + description="Filter by converter presence. true = attacks with at least one converter; " + "false = attacks with no converters. Omit for no filter.", ), outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), label: Optional[list[str]] = Query(None, description="Filter by labels (format: key:value, repeatable)"), @@ -89,12 +113,15 @@ async def list_attacks( """ service = get_attack_service() labels = _parse_labels(label) - # Normalize converter_types: strip empty strings so ?converter_types= means "no converters" + # Strip empty strings from ?converter_types= parameters; an all-empty list + # collapses to [] which means "no filter" at the memory layer. if converter_types is not None: converter_types = [c for c in converter_types if c] return await service.list_attacks_async( - attack_type=attack_type, + attack_types=attack_types, converter_types=converter_types, + converter_types_match=converter_types_match, + has_converters=has_converters, outcome=outcome, labels=labels, min_turns=min_turns, diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 5c8ea20b6..9f5a51a71 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -64,6 +64,15 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer +def _converter_classes_on_attack(ar: AttackResult) -> set[str]: + """Return the set of converter class_name values for the attack (lowercased).""" + aid = ar.get_attack_strategy_identifier() + if not aid: + return set() + converter_ids = aid.get_child_list("request_converters") + return {c.class_name.lower() for c in converter_ids} + + class AttackService: """ Service for managing attacks. @@ -82,10 +91,12 @@ def __init__(self) -> None: async def list_attacks_async( self, *, - attack_type: Optional[str] = None, + attack_types: Optional[list[str]] = None, converter_types: Optional[list[str]] = None, + converter_types_match: Literal["any", "all"] = "all", + has_converters: Optional[bool] = None, outcome: Optional[Literal["undetermined", "success", "failure"]] = None, - labels: Optional[dict[str, str]] = None, + labels: Optional[dict[str, str | Sequence[str]]] = None, min_turns: Optional[int] = None, max_turns: Optional[int] = None, limit: int = 20, @@ -97,12 +108,23 @@ async def list_attacks_async( Queries AttackResult entries from the database. Args: - attack_type: Filter by exact attack type name (case-sensitive). - converter_types: Filter by converter usage. - None = no filter, [] = only attacks with no converters, - ["A", "B"] = only attacks using ALL specified converters (AND logic, case-insensitive). + attack_types: Filter by attack type names (repeatable, OR-matched, case-sensitive). + None or empty list applies no filter. + converter_types: Filter by converter class names (case-insensitive). + ``None`` or empty list applies no filter. Combination semantics for + multiple entries are controlled by ``converter_types_match``. For the + "only attacks with no converters" case, use ``has_converters=False``. + converter_types_match: How to combine multiple entries in ``converter_types``. + ``"all"`` (default) matches attacks that used every listed converter. + ``"any"`` matches attacks that used at least one of the listed converters. + Ignored when ``converter_types`` is None or has fewer than 2 entries. + has_converters: Filter by converter presence. ``True`` returns only attacks that + used at least one converter. ``False`` returns only attacks that used no + converters. ``None`` applies no filter. outcome: Filter by attack outcome. - labels: Filter by labels (all must match). + labels: Filter by labels. See ``MemoryInterface.get_attack_results`` for + semantics (AND across label names; string equality or sequence OR within + each name). min_turns: Filter by minimum executed turns. max_turns: Filter by maximum executed turns. limit: Maximum items to return. @@ -112,13 +134,22 @@ async def list_attacks_async( AttackListResponse with filtered and paginated attack summaries. """ # Phase 1: Query + lightweight filtering (no pieces needed) + # For converter_types, push AND-logic down to the DB for "all" mode and for + # the degenerate cases (None, [], single entry). For "any" mode with 2+ + # entries, skip the DB filter and apply OR-matching in Python below. + push_converter_filter = converter_types is None or len(converter_types) <= 1 or converter_types_match == "all" attack_results = self._memory.get_attack_results( outcome=outcome, labels=labels if labels else None, - attack_class=attack_type, - converter_classes=converter_types, + attack_classes=attack_types if attack_types else None, + converter_classes=converter_types if push_converter_filter else None, + has_converters=has_converters, ) + if not push_converter_filter: + requested = {c.lower() for c in converter_types or []} + attack_results = [ar for ar in attack_results if _converter_classes_on_attack(ar) & requested] + filtered: list[AttackResult] = [] for ar in attack_results: if min_turns is not None and ar.executed_turns < min_turns: diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index cbfc74f3b..b597f43a7 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -87,13 +87,15 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: response = client.get( "/api/attacks", - params={"attack_type": "CrescendoAttack", "outcome": "success", "limit": 10}, + params={"attack_types": ["CrescendoAttack"], "outcome": "success", "limit": 10}, ) assert response.status_code == status.HTTP_200_OK mock_service.list_attacks_async.assert_called_once_with( - attack_type="CrescendoAttack", + attack_types=["CrescendoAttack"], converter_types=None, + converter_types_match="all", + has_converters=None, outcome="success", labels=None, min_turns=None, @@ -102,6 +104,63 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: cursor=None, ) + def test_list_attacks_multi_attack_types(self, client: TestClient) -> None: + """Test that repeated attack_types query params are forwarded as a list.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get( + "/api/attacks", + params=[("attack_types", "CrescendoAttack"), ("attack_types", "ManualAttack")], + ) + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args.kwargs + assert call_kwargs["attack_types"] == ["CrescendoAttack", "ManualAttack"] + + def test_list_attacks_has_converters_true(self, client: TestClient) -> None: + """?has_converters=true is parsed as bool True and forwarded.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks", params={"has_converters": "true"}) + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args.kwargs + assert call_kwargs["has_converters"] is True + + def test_list_attacks_has_converters_false(self, client: TestClient) -> None: + """?has_converters=false is parsed as bool False and forwarded.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks", params={"has_converters": "false"}) + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args.kwargs + assert call_kwargs["has_converters"] is False + def test_create_attack_success(self, client: TestClient) -> None: """Test successful attack creation.""" now = datetime.now(timezone.utc) @@ -422,7 +481,7 @@ def test_list_attacks_with_labels(self, client: TestClient) -> None: # Verify labels were parsed and passed to service mock_service.list_attacks_async.assert_called_once() call_kwargs = mock_service.list_attacks_async.call_args[1] - assert call_kwargs["labels"] == {"env": "prod", "team": "red"} + assert call_kwargs["labels"] == {"env": ["prod"], "team": ["red"]} def test_get_attack_options(self, client: TestClient) -> None: """Test getting attack type options from attack results.""" @@ -467,7 +526,7 @@ def test_parse_labels_skips_param_without_colon(self, client: TestClient) -> Non assert response.status_code == status.HTTP_200_OK call_kwargs = mock_service.list_attacks_async.call_args[1] # Only the valid label should be parsed - assert call_kwargs["labels"] == {"env": "prod"} + assert call_kwargs["labels"] == {"env": ["prod"]} def test_parse_labels_all_invalid_returns_none(self, client: TestClient) -> None: """Test that _parse_labels returns None when all params lack colons.""" @@ -503,7 +562,7 @@ def test_parse_labels_value_with_extra_colons(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK call_kwargs = mock_service.list_attacks_async.call_args[1] - assert call_kwargs["labels"] == {"url": "http://example.com:8080"} + assert call_kwargs["labels"] == {"url": ["http://example.com:8080"]} def test_parse_labels_passes_keys_through_without_normalization(self, client: TestClient) -> None: """Test that label keys are passed through as-is (DB stores canonical keys after migration).""" @@ -521,7 +580,7 @@ def test_parse_labels_passes_keys_through_without_normalization(self, client: Te assert response.status_code == status.HTTP_200_OK call_kwargs = mock_service.list_attacks_async.call_args[1] - assert call_kwargs["labels"] == {"operator": "alice", "operation": "redteam"} + assert call_kwargs["labels"] == {"operator": ["alice"], "operation": ["redteam"]} def test_list_attacks_forwards_converter_types_param(self, client: TestClient) -> None: """Test that converter_types query params are forwarded to service.""" @@ -559,6 +618,52 @@ def test_list_attacks_normalizes_empty_converter_types(self, client: TestClient) call_kwargs = mock_service.list_attacks_async.call_args[1] assert call_kwargs["converter_types"] == [] + def test_list_attacks_groups_repeated_label_key_as_list(self, client: TestClient) -> None: + """Repeated label keys produce OR-within-key: value list is forwarded to service.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?label=operator:alice&label=operator:bob") + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args[1] + assert call_kwargs["labels"] == {"operator": ["alice", "bob"]} + + def test_list_attacks_forwards_converter_types_match(self, client: TestClient) -> None: + """converter_types_match query param is forwarded verbatim to service.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?converter_types=A&converter_types=B&converter_types_match=any") + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args[1] + assert call_kwargs["converter_types_match"] == "any" + + def test_list_attacks_rejects_invalid_converter_types_match(self, client: TestClient) -> None: + """Invalid converter_types_match value returns 422 per Literal contract.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?converter_types_match=garbage") + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + def test_get_conversations_success(self, client: TestClient) -> None: """Test getting attack conversations returns service response.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 3d09fe6b7..e9f98cbb9 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -196,30 +196,52 @@ async def test_list_attacks_returns_attacks(self, attack_service, mock_memory) - assert result.items[0].attack_type == "Test Attack" @pytest.mark.asyncio - async def test_list_attacks_filters_by_attack_type_exact(self, attack_service, mock_memory) -> None: - """Test that list_attacks passes attack_type to memory layer.""" + async def test_list_attacks_filters_by_attack_types_exact(self, attack_service, mock_memory) -> None: + """Test that list_attacks passes attack_types to memory layer.""" ar1 = make_attack_result(conversation_id="attack-1", name="CrescendoAttack") mock_memory.get_attack_results.return_value = [ar1] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks_async(attack_type="CrescendoAttack") + result = await attack_service.list_attacks_async(attack_types=["CrescendoAttack"]) assert len(result.items) == 1 assert result.items[0].conversation_id == "attack-1" - # Verify attack_type was forwarded to the memory layer as attack_class + # Verify attack_types was forwarded to the memory layer as attack_classes call_kwargs = mock_memory.get_attack_results.call_args[1] - assert call_kwargs["attack_class"] == "CrescendoAttack" + assert call_kwargs["attack_classes"] == ["CrescendoAttack"] @pytest.mark.asyncio - async def test_list_attacks_attack_type_passed_to_memory(self, attack_service, mock_memory) -> None: - """Test that attack_type is forwarded to memory as attack_class for DB-level filtering.""" + async def test_list_attacks_attack_types_passed_to_memory(self, attack_service, mock_memory) -> None: + """Test that attack_types is forwarded to memory as attack_classes for DB-level filtering.""" mock_memory.get_attack_results.return_value = [] mock_memory.get_message_pieces.return_value = [] - await attack_service.list_attacks_async(attack_type="Crescendo") + await attack_service.list_attacks_async(attack_types=["Crescendo"]) call_kwargs = mock_memory.get_attack_results.call_args[1] - assert call_kwargs["attack_class"] == "Crescendo" + assert call_kwargs["attack_classes"] == ["Crescendo"] + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_attack_types_multi(self, attack_service, mock_memory) -> None: + """Test that multiple attack_types are forwarded as a list to memory for OR-matching.""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async(attack_types=["CrescendoAttack", "ManualAttack"]) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["attack_classes"] == ["CrescendoAttack", "ManualAttack"] + + @pytest.mark.asyncio + async def test_list_attacks_attack_types_empty_list_coerced_to_none(self, attack_service, mock_memory) -> None: + """Test that attack_types=[] is coerced to None before reaching memory (no filter).""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async(attack_types=[]) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["attack_classes"] is None @pytest.mark.asyncio async def test_list_attacks_filters_by_no_converters(self, attack_service, mock_memory) -> None: @@ -232,6 +254,28 @@ async def test_list_attacks_filters_by_no_converters(self, attack_service, mock_ call_kwargs = mock_memory.get_attack_results.call_args[1] assert call_kwargs["converter_classes"] == [] + @pytest.mark.asyncio + async def test_list_attacks_forwards_has_converters_true(self, attack_service, mock_memory) -> None: + """has_converters=True is forwarded to memory.""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async(has_converters=True) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["has_converters"] is True + + @pytest.mark.asyncio + async def test_list_attacks_forwards_has_converters_false(self, attack_service, mock_memory) -> None: + """has_converters=False is forwarded to memory.""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async(has_converters=False) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["has_converters"] is False + @pytest.mark.asyncio async def test_list_attacks_filters_by_converter_types_and_logic(self, attack_service, mock_memory) -> None: """Test that list_attacks passes converter_types to memory layer.""" @@ -273,6 +317,86 @@ async def test_list_attacks_filters_by_converter_types_and_logic(self, attack_se call_kwargs = mock_memory.get_attack_results.call_args[1] assert call_kwargs["converter_classes"] == ["Base64Converter", "ROT13Converter"] + @pytest.mark.asyncio + async def test_list_attacks_converter_match_all_explicit_pushes_to_memory( + self, attack_service, mock_memory + ) -> None: + """Explicit converter_types_match='all' still pushes converter filter to memory.""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async( + converter_types=["Base64Converter", "ROT13Converter"], + converter_types_match="all", + ) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["converter_classes"] == ["Base64Converter", "ROT13Converter"] + + @pytest.mark.asyncio + async def test_list_attacks_converter_match_any_single_converter_pushes_to_memory( + self, attack_service, mock_memory + ) -> None: + """Degenerate case: converter_types_match='any' with one converter still pushes to memory.""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async( + converter_types=["Base64Converter"], + converter_types_match="any", + ) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["converter_classes"] == ["Base64Converter"] + + @pytest.mark.asyncio + async def test_list_attacks_converter_match_any_filters_in_python(self, attack_service, mock_memory) -> None: + """converter_types_match='any' with 2+ converters skips DB pushdown and filters in Python. + + Fixture: attack-1 has Base64 only, attack-2 has ROT13 only, attack-3 has neither. + Request ['Base64Converter', 'ROT13Converter'] with any -> attack-1 and attack-2 returned. + """ + + def _ar_with_converters(conv_id: str, converter_class_names: list[str]) -> AttackResult: + ar = make_attack_result(conversation_id=conv_id, name=f"Attack {conv_id}") + ar.atomic_attack_identifier = build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier( + class_name=f"Attack {conv_id}", + class_module="pyrit.backend", + children={ + "request_converters": [ + ComponentIdentifier( + class_name=name, + class_module="pyrit.converters", + params={ + "supported_input_types": ("text",), + "supported_output_types": ("text",), + }, + ) + for name in converter_class_names + ], + }, + ), + ) + return ar + + ar1 = _ar_with_converters("attack-1", ["Base64Converter"]) + ar2 = _ar_with_converters("attack-2", ["ROT13Converter"]) + ar3 = _ar_with_converters("attack-3", ["LeetspeakConverter"]) + mock_memory.get_attack_results.return_value = [ar1, ar2, ar3] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks_async( + converter_types=["Base64Converter", "ROT13Converter"], + converter_types_match="any", + ) + + # attack-3 is filtered out in Python; attack-1 and attack-2 remain + assert {item.conversation_id for item in result.items} == {"attack-1", "attack-2"} + # Memory was called WITHOUT converter_classes so pushdown is skipped + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert "converter_classes" not in call_kwargs or call_kwargs["converter_classes"] is None + @pytest.mark.asyncio async def test_list_attacks_filters_by_min_turns(self, attack_service, mock_memory) -> None: """Test that list_attacks filters by minimum executed turns.""" From 37d07d76301ca47a58e19d25e8a991159de922fc Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Wed, 22 Apr 2026 16:13:18 -0400 Subject: [PATCH 3/8] Frontend: migrate Attack History filters to multi-select Comboboxes Replace Dropdowns with searchable multi-select Comboboxes for attack class, converter, operator, and operation filters. The converter Combobox includes a "(No converters)" sentinel option that maps to has_converters=false. A match-mode Switch (ANY|ALL) appears when two or more converters are selected. Adapts AttackHistory.fetchAttacks to send attack_types (repeated), converter_types_match, and has_converters to the API. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../components/History/AttackHistory.test.tsx | 137 +++++++++++++++++- .../src/components/History/AttackHistory.tsx | 26 +++- .../History/HistoryFiltersBar.test.tsx | 122 +++++++++++++++- .../components/History/HistoryFiltersBar.tsx | 123 +++++++++++----- .../src/components/History/historyFilters.ts | 22 ++- frontend/src/services/api.ts | 4 +- 6 files changed, 372 insertions(+), 62 deletions(-) diff --git a/frontend/src/components/History/AttackHistory.test.tsx b/frontend/src/components/History/AttackHistory.test.tsx index a4dfea3ea..1c7d150ff 100644 --- a/frontend/src/components/History/AttackHistory.test.tsx +++ b/frontend/src/components/History/AttackHistory.test.tsx @@ -695,7 +695,7 @@ describe('AttackHistory', () => { fireEvent.click(screen.getByText('CrescendoAttack')) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ attackClass: 'CrescendoAttack' }) + expect.objectContaining({ attackClasses: ['CrescendoAttack'] }) ) }) @@ -767,7 +767,7 @@ describe('AttackHistory', () => { fireEvent.click(screen.getByText('Base64Converter')) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ converter: 'Base64Converter' }) + expect.objectContaining({ converter: ['Base64Converter'] }) ) }) @@ -808,7 +808,7 @@ describe('AttackHistory', () => { fireEvent.click(screen.getByText('alice')) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ operator: 'alice' }) + expect.objectContaining({ operator: ['alice'] }) ) }) @@ -849,7 +849,7 @@ describe('AttackHistory', () => { fireEvent.click(screen.getByText('op_alpha')) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ operation: 'op_alpha' }) + expect.objectContaining({ operation: ['op_alpha'] }) ) }) @@ -900,9 +900,9 @@ describe('AttackHistory', () => { const onFiltersChange = jest.fn() const activeFilters = { ...DEFAULT_HISTORY_FILTERS, - attackClass: 'CrescendoAttack', + attackClasses: ['CrescendoAttack'], outcome: 'success', - operator: 'alice', + operator: ['alice'], } render( @@ -961,4 +961,129 @@ describe('AttackHistory', () => { expect.objectContaining({ otherLabels: expect.any(Array), labelSearchText: '' }) ) }) + + it('should forward multi-select attackClasses as attack_types array to the API', async () => { + mockedAttacksApi.listAttacks.mockResolvedValue({ + items: [], + pagination: { limit: 25, has_more: false }, + }) + mockedAttacksApi.getAttackOptions.mockResolvedValue({ attack_types: [] }) + mockedAttacksApi.getConverterOptions.mockResolvedValue({ converter_types: [] }) + mockedLabelsApi.getLabels.mockResolvedValue({ source: 'attacks', labels: {} }) + + const filters = { + ...DEFAULT_HISTORY_FILTERS, + attackClasses: ['CrescendoAttack', 'RedTeamingAttack'], + } + + render( + + + + ) + + await waitFor(() => { + expect(mockedAttacksApi.listAttacks).toHaveBeenCalled() + }) + + expect(mockedAttacksApi.listAttacks).toHaveBeenCalledWith( + expect.objectContaining({ attack_types: ['CrescendoAttack', 'RedTeamingAttack'] }) + ) + }) + + it('should forward hasConverters=false without converter_types or converter_types_match', async () => { + mockedAttacksApi.listAttacks.mockResolvedValue({ + items: [], + pagination: { limit: 25, has_more: false }, + }) + mockedAttacksApi.getAttackOptions.mockResolvedValue({ attack_types: [] }) + mockedAttacksApi.getConverterOptions.mockResolvedValue({ converter_types: [] }) + mockedLabelsApi.getLabels.mockResolvedValue({ source: 'attacks', labels: {} }) + + const filters = { + ...DEFAULT_HISTORY_FILTERS, + hasConverters: false, + converter: [], + } + + render( + + + + ) + + await waitFor(() => { + expect(mockedAttacksApi.listAttacks).toHaveBeenCalled() + }) + + const callArgs = mockedAttacksApi.listAttacks.mock.calls[0][0] + expect(callArgs).toEqual(expect.objectContaining({ has_converters: false })) + expect(callArgs).not.toHaveProperty('converter_types') + expect(callArgs).not.toHaveProperty('converter_types_match') + }) + + it('should only send converter_types_match when two or more converters are selected', async () => { + mockedAttacksApi.listAttacks.mockResolvedValue({ + items: [], + pagination: { limit: 25, has_more: false }, + }) + mockedAttacksApi.getAttackOptions.mockResolvedValue({ attack_types: [] }) + mockedAttacksApi.getConverterOptions.mockResolvedValue({ converter_types: [] }) + mockedLabelsApi.getLabels.mockResolvedValue({ source: 'attacks', labels: {} }) + + // Case 1: single converter → converter_types_match is NOT sent + const singleFilters = { + ...DEFAULT_HISTORY_FILTERS, + converter: ['Base64Converter'], + converterMatchMode: 'all' as const, + } + + const { unmount } = render( + + + + ) + + await waitFor(() => { + expect(mockedAttacksApi.listAttacks).toHaveBeenCalled() + }) + + const singleCallArgs = mockedAttacksApi.listAttacks.mock.calls[0][0] + expect(singleCallArgs).toEqual(expect.objectContaining({ converter_types: ['Base64Converter'] })) + expect(singleCallArgs).not.toHaveProperty('converter_types_match') + + unmount() + jest.clearAllMocks() + mockedAttacksApi.listAttacks.mockResolvedValue({ + items: [], + pagination: { limit: 25, has_more: false }, + }) + mockedAttacksApi.getAttackOptions.mockResolvedValue({ attack_types: [] }) + mockedAttacksApi.getConverterOptions.mockResolvedValue({ converter_types: [] }) + mockedLabelsApi.getLabels.mockResolvedValue({ source: 'attacks', labels: {} }) + + // Case 2: two converters → converter_types_match IS sent + const multiFilters = { + ...DEFAULT_HISTORY_FILTERS, + converter: ['Base64Converter', 'ROT13Converter'], + converterMatchMode: 'all' as const, + } + + render( + + + + ) + + await waitFor(() => { + expect(mockedAttacksApi.listAttacks).toHaveBeenCalled() + }) + + expect(mockedAttacksApi.listAttacks).toHaveBeenCalledWith( + expect.objectContaining({ + converter_types: ['Base64Converter', 'ROT13Converter'], + converter_types_match: 'all', + }) + ) + }) }) diff --git a/frontend/src/components/History/AttackHistory.tsx b/frontend/src/components/History/AttackHistory.tsx index 2bb828ac0..e570d72c7 100644 --- a/frontend/src/components/History/AttackHistory.tsx +++ b/frontend/src/components/History/AttackHistory.tsx @@ -47,16 +47,18 @@ export default function AttackHistory({ onOpenAttack, filters, onFiltersChange } setError(null) try { const labelParams: string[] = [] - if (filters.operator) { labelParams.push(`operator:${filters.operator}`) } - if (filters.operation) { labelParams.push(`operation:${filters.operation}`) } + for (const op of filters.operator) { labelParams.push(`operator:${op}`) } + for (const op of filters.operation) { labelParams.push(`operation:${op}`) } labelParams.push(...filters.otherLabels) const response = await attacksApi.listAttacks({ limit: PAGE_SIZE, ...(pageCursor && { cursor: pageCursor }), - ...(filters.attackClass && { attack_type: filters.attackClass }), + ...(filters.attackClasses.length > 0 && { attack_types: filters.attackClasses }), ...(filters.outcome && { outcome: filters.outcome }), - ...(filters.converter && { converter_types: [filters.converter] }), + ...(filters.converter.length > 0 && { converter_types: filters.converter }), + ...(filters.converter.length >= 2 && { converter_types_match: filters.converterMatchMode }), + ...(filters.hasConverters !== undefined && { has_converters: filters.hasConverters }), ...(labelParams.length > 0 && { label: labelParams }), }) setAttacks(response.items.map(attack => ({ ...attack, labels: attack.labels ?? {} }))) @@ -68,7 +70,16 @@ export default function AttackHistory({ onOpenAttack, filters, onFiltersChange } } finally { setLoading(false) } - }, [filters.attackClass, filters.outcome, filters.converter, filters.operator, filters.operation, filters.otherLabels]) + }, [ + filters.attackClasses, + filters.outcome, + filters.converter, + filters.converterMatchMode, + filters.hasConverters, + filters.operator, + filters.operation, + filters.otherLabels, + ]) // Load filter options on mount useEffect(() => { @@ -132,8 +143,9 @@ export default function AttackHistory({ onOpenAttack, filters, onFiltersChange } } const hasActiveFilters = - filters.attackClass || filters.outcome || filters.converter || - filters.operator || filters.operation || filters.otherLabels.length > 0 + filters.attackClasses.length > 0 || filters.outcome || filters.converter.length > 0 || + filters.hasConverters !== undefined || + filters.operator.length > 0 || filters.operation.length > 0 || filters.otherLabels.length > 0 return (
diff --git a/frontend/src/components/History/HistoryFiltersBar.test.tsx b/frontend/src/components/History/HistoryFiltersBar.test.tsx index f69f7976a..48f2ae4f5 100644 --- a/frontend/src/components/History/HistoryFiltersBar.test.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.test.tsx @@ -65,7 +65,7 @@ describe('HistoryFiltersBar', () => { it('should call onFiltersChange with defaults when reset is clicked', () => { const onFiltersChange = jest.fn() - const activeFilters = { ...DEFAULT_HISTORY_FILTERS, outcome: 'success', operator: 'alice' } + const activeFilters = { ...DEFAULT_HISTORY_FILTERS, outcome: 'success', operator: ['alice'] } render( @@ -98,7 +98,7 @@ describe('HistoryFiltersBar', () => { fireEvent.click(option) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ attackClass: 'CrescendoAttack' }) + expect.objectContaining({ attackClasses: ['CrescendoAttack'] }) ) }) @@ -143,7 +143,7 @@ describe('HistoryFiltersBar', () => { fireEvent.click(option) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ converter: 'ROT13Converter' }) + expect.objectContaining({ converter: ['ROT13Converter'] }) ) }) @@ -168,7 +168,7 @@ describe('HistoryFiltersBar', () => { fireEvent.click(option) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ operator: 'bob' }) + expect.objectContaining({ operator: ['bob'] }) ) }) @@ -193,7 +193,7 @@ describe('HistoryFiltersBar', () => { fireEvent.click(option) expect(onFiltersChange).toHaveBeenCalledWith( - expect.objectContaining({ operation: 'op_beta' }) + expect.objectContaining({ operation: ['op_beta'] }) ) }) @@ -256,4 +256,116 @@ describe('HistoryFiltersBar', () => { expect(screen.getByTestId('reset-filters-btn')).toBeInTheDocument() }) + + it('should set hasConverters=false and clear converter list when "(No converters)" sentinel is selected', async () => { + const onFiltersChange = jest.fn() + const props = { + ...defaultProps, + onFiltersChange, + converterOptions: ['Base64Converter', 'ROT13Converter'], + } + + render( + + + + ) + + fireEvent.click(screen.getByTestId('converter-filter')) + const option = await screen.findByText('(No converters)') + fireEvent.click(option) + + expect(onFiltersChange).toHaveBeenCalledWith( + expect.objectContaining({ converter: [], hasConverters: false }) + ) + }) + + it('should replace sentinel with real converter when user picks a converter while sentinel is active', async () => { + const onFiltersChange = jest.fn() + const props = { + ...defaultProps, + onFiltersChange, + converterOptions: ['Base64Converter', 'ROT13Converter'], + filters: { ...DEFAULT_HISTORY_FILTERS, hasConverters: false as boolean | undefined }, + } + + render( + + + + ) + + fireEvent.click(screen.getByTestId('converter-filter')) + const option = await screen.findByText('ROT13Converter') + fireEvent.click(option) + + expect(onFiltersChange).toHaveBeenCalledWith( + expect.objectContaining({ converter: ['ROT13Converter'], hasConverters: undefined }) + ) + }) + + it('should replace real converters with sentinel when user picks sentinel while real converters are active', async () => { + const onFiltersChange = jest.fn() + const props = { + ...defaultProps, + onFiltersChange, + converterOptions: ['Base64Converter', 'ROT13Converter'], + filters: { ...DEFAULT_HISTORY_FILTERS, converter: ['Base64Converter'] }, + } + + render( + + + + ) + + fireEvent.click(screen.getByTestId('converter-filter')) + const option = await screen.findByText('(No converters)') + fireEvent.click(option) + + expect(onFiltersChange).toHaveBeenCalledWith( + expect.objectContaining({ converter: [], hasConverters: false }) + ) + }) + + it('should not render the match-mode toggle when fewer than two converters are selected', () => { + const props = { + ...defaultProps, + converterOptions: ['Base64Converter', 'ROT13Converter'], + filters: { ...DEFAULT_HISTORY_FILTERS, converter: ['Base64Converter'] }, + } + + render( + + + + ) + + expect(screen.queryByTestId('converter-match-mode-toggle')).not.toBeInTheDocument() + }) + + it('should render the match-mode toggle and emit "all" when flipped with two converters selected', () => { + const onFiltersChange = jest.fn() + const props = { + ...defaultProps, + onFiltersChange, + converterOptions: ['Base64Converter', 'ROT13Converter'], + filters: { ...DEFAULT_HISTORY_FILTERS, converter: ['Base64Converter', 'ROT13Converter'] }, + } + + render( + + + + ) + + const toggle = screen.getByTestId('converter-match-mode-toggle') + expect(toggle).toBeInTheDocument() + + fireEvent.click(toggle) + + expect(onFiltersChange).toHaveBeenCalledWith( + expect.objectContaining({ converterMatchMode: 'all' }) + ) + }) }) diff --git a/frontend/src/components/History/HistoryFiltersBar.tsx b/frontend/src/components/History/HistoryFiltersBar.tsx index c62964b97..1f09126f8 100644 --- a/frontend/src/components/History/HistoryFiltersBar.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.tsx @@ -3,7 +3,9 @@ import { Tooltip, Dropdown, Option, + OptionGroup, Combobox, + Switch, } from '@fluentui/react-components' import { FilterRegular, @@ -13,6 +15,8 @@ import { DEFAULT_HISTORY_FILTERS } from './historyFilters' import type { HistoryFilters } from './historyFilters' import { useAttackHistoryStyles } from './AttackHistory.styles' +const NO_CONVERTERS_SENTINEL = '__no_converters__' + interface HistoryFiltersBarProps { filters: HistoryFilters onFiltersChange: (filters: HistoryFilters) => void @@ -35,11 +39,13 @@ export default function HistoryFiltersBar({ const styles = useAttackHistoryStyles() const { - attackClass: attackClassFilter, + attackClasses: attackClassFilters, outcome: outcomeFilter, converter: converterFilter, - operator: operatorFilter, - operation: operationFilter, + converterMatchMode, + hasConverters, + operator: operatorFilters, + operation: operationFilters, otherLabels: otherLabelFilters, labelSearchText, } = filters @@ -49,8 +55,35 @@ export default function HistoryFiltersBar({ } const hasActiveFilters = - attackClassFilter || outcomeFilter || converterFilter || - operatorFilter || operationFilter || otherLabelFilters.length > 0 + attackClassFilters.length > 0 || + outcomeFilter || + converterFilter.length > 0 || + hasConverters !== undefined || + operatorFilters.length > 0 || + operationFilters.length > 0 || + otherLabelFilters.length > 0 + + // Converter Combobox selectedOptions includes the sentinel when hasConverters=false. + const converterSelectedOptions = hasConverters === false + ? [NO_CONVERTERS_SENTINEL] + : converterFilter + + const handleConverterSelect = (selected: string[]) => { + const hasSentinel = selected.includes(NO_CONVERTERS_SENTINEL) + const realConverters = selected.filter((s) => s !== NO_CONVERTERS_SENTINEL) + const sentinelWasActive = hasConverters === false + const sentinelJustAdded = hasSentinel && !sentinelWasActive + + if (sentinelJustAdded || (hasSentinel && realConverters.length === 0)) { + // User just toggled the sentinel on (or it's alone) → clear real converters + onFiltersChange({ ...filters, converter: [], hasConverters: false }) + } else { + // Any real converter toggled (sentinel was active or not) → drop sentinel + onFiltersChange({ ...filters, converter: realConverters, hasConverters: undefined }) + } + } + + const showMatchModeToggle = converterFilter.length >= 2 && hasConverters !== false return (
@@ -68,19 +101,18 @@ export default function HistoryFiltersBar({ )} - setFilter('attackClass', data.optionValue ?? '')} + multiselect + selectedOptions={attackClassFilters} + onOptionSelect={(_e, data) => setFilter('attackClasses', data.selectedOptions)} data-testid="attack-class-filter" > - - {attackClassOptions.map(cls => ( + {attackClassOptions.map((cls) => ( ))} - + Failure - setFilter('converter', data.optionValue ?? '')} + multiselect + selectedOptions={converterSelectedOptions} + onOptionSelect={(_e, data) => handleConverterSelect(data.selectedOptions)} data-testid="converter-filter" > - - {converterOptions.map(c => ( - - ))} - - + + + + {converterOptions.map((c) => ( + + ))} + + + {showMatchModeToggle && ( + + + setFilter('converterMatchMode', data.checked ? 'all' : 'any') + } + data-testid="converter-match-mode-toggle" + /> + + )} + setFilter('operator', data.optionValue ?? '')} + multiselect + selectedOptions={operatorFilters} + onOptionSelect={(_e, data) => setFilter('operator', data.selectedOptions)} data-testid="operator-filter" > - - {operatorOptions.map(o => ( + {operatorOptions.map((o) => ( ))} - - + setFilter('operation', data.optionValue ?? '')} + multiselect + selectedOptions={operationFilters} + onOptionSelect={(_e, data) => setFilter('operation', data.selectedOptions)} data-testid="operation-filter" > - - {operationOptions.map(o => ( + {operationOptions.map((o) => ( ))} - + Date: Wed, 22 Apr 2026 18:04:46 -0400 Subject: [PATCH 4/8] Memory: preserve back-compat (attack_class param, converter_classes=[] semantic) Restores the deprecated singular `attack_class` parameter as a forwarder to `attack_classes` and reverts `converter_classes=[]` to its original "no converters" meaning. Library callers see no behavior change; the new `attack_classes`/`has_converters` params remain available for explicit multi/presence filtering. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/memory_interface.py | 27 +++++++++++---- .../test_interface_attack_results.py | 33 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index eb4d31bd4..19c18f51c 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1458,6 +1458,7 @@ def get_attack_results( objective: Optional[str] = None, objective_sha256: Optional[Sequence[str]] = None, outcome: Optional[str] = None, + attack_class: Optional[str] = None, attack_classes: Optional[Sequence[str]] = None, converter_classes: Optional[Sequence[str]] = None, has_converters: Optional[bool] = None, @@ -1476,14 +1477,17 @@ def get_attack_results( Defaults to None. outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). Defaults to None. + attack_class (Optional[str], optional): Deprecated. Filter by a single exact attack + class_name in attack_identifier. Equivalent to passing ``attack_classes=[attack_class]``. + Cannot be combined with ``attack_classes``. Defaults to None. attack_classes (Optional[Sequence[str]], optional): Filter by exact attack class_name in attack_identifier. Returns attacks matching ANY of the listed class names (OR logic, case-sensitive). An empty sequence applies no filter. Defaults to None. converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. Returns only attacks that used ALL specified converters (AND logic, case-insensitive). - An empty sequence applies no filter. To filter by presence/absence of any converter, - use the ``has_converters`` parameter instead. - Defaults to None. + An empty sequence filters to attacks that used no converters; ``None`` applies no + filter. To filter by presence/absence of any converter explicitly, use the + ``has_converters`` parameter instead. Defaults to None. has_converters (Optional[bool], optional): Filter by converter presence. ``True`` returns only attacks that used at least one converter. ``False`` returns only attacks that used no converters. ``None`` applies no filter. Defaults to None. @@ -1508,9 +1512,19 @@ def get_attack_results( Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. + + Raises: + ValueError: If both ``attack_class`` (deprecated) and ``attack_classes`` are provided. """ conditions: list[ColumnElement[bool]] = [] + if attack_class is not None and attack_classes is not None: + raise ValueError( + "Pass either `attack_class` (deprecated, singular) or `attack_classes` (plural), not both." + ) + if attack_class is not None and attack_classes is None: + attack_classes = [attack_class] + if attack_result_ids is not None: if len(attack_result_ids) == 0: # Empty list means no results @@ -1542,9 +1556,10 @@ def get_attack_results( ) ) - if converter_classes: - # Non-empty sequence filters by converter class names (ALL must be present). - # Empty sequence or None applies no filter. + if converter_classes is not None: + # Non-empty sequence: filter to attacks that used ALL listed converters. + # Empty sequence: filter to attacks that used NO converters. + # None: no filter. conditions.append( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index dadd67767..d812e222a 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -5,6 +5,8 @@ import uuid from typing import TYPE_CHECKING, Optional +import pytest + from pyrit.common.utils import to_sha256 from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier @@ -1272,8 +1274,8 @@ def test_get_attack_results_converter_classes_none_returns_all(sqlite_instance: assert len(results) == 3 -def test_get_attack_results_converter_classes_empty_returns_all(sqlite_instance: MemoryInterface): - """Test that converter_classes=[] behaves like None (no filter applied).""" +def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite_instance: MemoryInterface): + """Test that converter_classes=[] returns only attacks with no converters (back-compat).""" ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") # converter_ids=None ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) # converter_ids=[] @@ -1283,7 +1285,8 @@ def test_get_attack_results_converter_classes_empty_returns_all(sqlite_instance: ) results = sqlite_instance.get_attack_results(converter_classes=[]) - assert len(results) == 4 + conv_ids = {r.conversation_id for r in results} + assert conv_ids == {"conv_2", "conv_3", "conv_4"} def test_get_attack_results_converter_classes_single_match(sqlite_instance: MemoryInterface): @@ -1346,15 +1349,33 @@ def test_get_attack_results_attack_classes_and_converter_classes_combined(sqlite assert results[0].conversation_id == "conv_1" -def test_get_attack_results_attack_classes_converter_classes_empty_no_filter(sqlite_instance: MemoryInterface): - """Combining attack_classes with converter_classes=[] applies no converter filter.""" +def test_get_attack_results_attack_classes_converter_classes_empty_matches_no_converters( + sqlite_instance: MemoryInterface, +): + """Combining attack_classes with converter_classes=[] restricts to the class with no converters.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") # No converters ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # Different class sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) results = sqlite_instance.get_attack_results(attack_classes=["CrescendoAttack"], converter_classes=[]) - assert {r.conversation_id for r in results} == {"conv_1", "conv_2"} + assert {r.conversation_id for r in results} == {"conv_2"} + + +def test_get_attack_results_attack_class_backcompat_singular(sqlite_instance: MemoryInterface): + """Deprecated singular attack_class=... still works and is equivalent to attack_classes=[...].""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + assert {r.conversation_id for r in results} == {"conv_1"} + + +def test_get_attack_results_attack_class_and_attack_classes_both_raises(sqlite_instance: MemoryInterface): + """Passing both attack_class and attack_classes is rejected.""" + with pytest.raises(ValueError, match="attack_class"): + sqlite_instance.get_attack_results(attack_class="A", attack_classes=["B"]) def test_get_attack_results_has_converters_true(sqlite_instance: MemoryInterface): From 8b8aaee392f75dd4da055b4b5a4a3655ddfe9f1d Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Wed, 22 Apr 2026 18:04:56 -0400 Subject: [PATCH 5/8] Frontend: fix multi-select display and outcome filter persistence Multi-select Comboboxes now display " (+N)" instead of always showing the placeholder. Swaps the outcome filter from Dropdown to single-select Combobox so its visible label persists across re-renders triggered by sibling filter changes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../History/HistoryFiltersBar.test.tsx | 62 +++++++++++++++++++ .../components/History/HistoryFiltersBar.tsx | 33 ++++++++-- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/frontend/src/components/History/HistoryFiltersBar.test.tsx b/frontend/src/components/History/HistoryFiltersBar.test.tsx index 48f2ae4f5..b2e15bdd0 100644 --- a/frontend/src/components/History/HistoryFiltersBar.test.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.test.tsx @@ -368,4 +368,66 @@ describe('HistoryFiltersBar', () => { expect.objectContaining({ converterMatchMode: 'all' }) ) }) + + it('should display "name (+N)" when multiple values are selected in a multi-select Combobox', () => { + const props = { + ...defaultProps, + operatorOptions: ['alice', 'bob', 'carol'], + filters: { ...DEFAULT_HISTORY_FILTERS, operator: ['alice', 'bob', 'carol'] }, + } + + render( + + + + ) + + const input = screen.getByTestId('operator-filter') as HTMLInputElement + expect(input.value).toBe('alice (+2)') + }) + + it('should display just the name when exactly one value is selected in a multi-select Combobox', () => { + const props = { + ...defaultProps, + operatorOptions: ['alice', 'bob'], + filters: { ...DEFAULT_HISTORY_FILTERS, operator: ['alice'] }, + } + + render( + + + + ) + + const input = screen.getByTestId('operator-filter') as HTMLInputElement + expect(input.value).toBe('alice') + }) + + it('should keep the outcome Dropdown display label after a sibling filter changes', () => { + const props = { + ...defaultProps, + attackClassOptions: ['PromptSendingAttack'], + filters: { ...DEFAULT_HISTORY_FILTERS, outcome: 'success' }, + } + + const { rerender } = render( + + + + ) + + const outcomeInput = screen.getByTestId('outcome-filter') as HTMLInputElement + expect(outcomeInput.value).toBe('Success') + + rerender( + + + + ) + + expect((screen.getByTestId('outcome-filter') as HTMLInputElement).value).toBe('Success') + }) }) diff --git a/frontend/src/components/History/HistoryFiltersBar.tsx b/frontend/src/components/History/HistoryFiltersBar.tsx index 1f09126f8..1472c7a6a 100644 --- a/frontend/src/components/History/HistoryFiltersBar.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.tsx @@ -1,7 +1,6 @@ import { Button, Tooltip, - Dropdown, Option, OptionGroup, Combobox, @@ -17,6 +16,20 @@ import { useAttackHistoryStyles } from './AttackHistory.styles' const NO_CONVERTERS_SENTINEL = '__no_converters__' +const OUTCOME_LABELS: Record = { + success: 'Success', + failure: 'Failure', + undetermined: 'Undetermined', +} + +// Fluent's multiselect Combobox doesn't auto-populate its input from selectedOptions; +// we have to drive the displayed text via `value` ourselves. +function formatMultiSelectValue(selected: string[]): string { + if (selected.length === 0) return '' + if (selected.length === 1) return selected[0] + return `${selected[0]} (+${selected.length - 1})` +} + interface HistoryFiltersBarProps { filters: HistoryFilters onFiltersChange: (filters: HistoryFilters) => void @@ -106,6 +119,7 @@ export default function HistoryFiltersBar({ placeholder="All attack types" multiselect selectedOptions={attackClassFilters} + value={formatMultiSelectValue(attackClassFilters)} onOptionSelect={(_e, data) => setFilter('attackClasses', data.selectedOptions)} data-testid="attack-class-filter" > @@ -113,24 +127,31 @@ export default function HistoryFiltersBar({ ))} - setFilter('outcome', data.optionValue ?? '')} + onOptionSelect={(_e, data) => + setFilter('outcome', data.selectedOptions[0] ?? '') + } data-testid="outcome-filter" > - + handleConverterSelect(data.selectedOptions)} data-testid="converter-filter" > @@ -167,6 +188,7 @@ export default function HistoryFiltersBar({ placeholder="All operators" multiselect selectedOptions={operatorFilters} + value={formatMultiSelectValue(operatorFilters)} onOptionSelect={(_e, data) => setFilter('operator', data.selectedOptions)} data-testid="operator-filter" > @@ -179,6 +201,7 @@ export default function HistoryFiltersBar({ placeholder="All operations" multiselect selectedOptions={operationFilters} + value={formatMultiSelectValue(operationFilters)} onOptionSelect={(_e, data) => setFilter('operation', data.selectedOptions)} data-testid="operation-filter" > From d62485e5a2e00cb83ac598f25499e98b68c3a300 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Wed, 22 Apr 2026 18:54:22 -0400 Subject: [PATCH 6/8] Tests + review fixes: coerce empty converter_types, skip empty-sequence label keys - Service layer coerces converter_types=[] to 'no filter'; the 'no converters' intent is expressed via has_converters=False. Keeps route/service/memory semantics consistent. - memory_interface.get_attack_results strips label keys whose value is an empty sequence before emitting the label condition (previously produced a strictly-more-restrictive EXISTS predicate). - Skip the redundant has_converters=True predicate when converter_classes is already non-empty. - Adds unit tests for the empty-sequence/no-identifier branches in the memory and backend service layers. - Fixes two Attack History e2e specs to click the Combobox input (Fluent v9 multiselect doesn't open the listbox on wrapper click). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- frontend/e2e/history.spec.ts | 8 ++--- pyrit/backend/routes/attacks.py | 4 +-- pyrit/backend/services/attack_service.py | 23 +++++++++---- pyrit/memory/memory_interface.py | 20 +++++++++--- tests/unit/backend/test_attack_service.py | 22 +++++++++++-- .../test_interface_attack_results.py | 31 ++++++++++++++++++ tests/unit/memory/test_azure_sql_memory.py | 32 +++++++++++++++++++ uv.lock | 2 +- 8 files changed, 120 insertions(+), 22 deletions(-) diff --git a/frontend/e2e/history.spec.ts b/frontend/e2e/history.spec.ts index bbee24502..e8ef93b75 100644 --- a/frontend/e2e/history.spec.ts +++ b/frontend/e2e/history.spec.ts @@ -210,9 +210,9 @@ test.describe("Attack History Filters", () => { await expect(page.getByTestId("attack-row-atk-alice-a")).toBeVisible(); await expect(page.getByTestId("attack-row-atk-bob-b")).toBeVisible(); - // Open the attack class dropdown and select SingleTurnAttack + // Open the attack class combobox (multiselect: must click the input, not the wrapper) const dropdown = page.getByTestId("attack-class-filter"); - await dropdown.click(); + await dropdown.locator("input").click(); await page.getByRole("option", { name: "SingleTurnAttack" }).click(); // Only SingleTurnAttack attacks should be visible @@ -243,9 +243,9 @@ test.describe("Attack History Filters", () => { await mockHistoryAPIs(page); await goToHistory(page); - // Open operator dropdown and select "bob" + // Open the operator combobox (multiselect: must click the input, not the wrapper) const operatorDropdown = page.getByTestId("operator-filter"); - await operatorDropdown.click(); + await operatorDropdown.locator("input").click(); await page.getByRole("option", { name: "bob" }).click(); // Only bob's attacks diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index ed06360a3..4ab0ce5a6 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -78,8 +78,8 @@ async def list_attacks( None, description="Filter by converter type names (repeatable). " "Combined per converter_types_match. " - "Omit to return all attacks regardless of converters. " - "To filter for attacks with no converters, use has_converters=false.", + "Omit (or pass an empty value) to apply no converter filter. " + "To restrict to attacks with no converters, use has_converters=false.", ), converter_types_match: Literal["any", "all"] = Query( "all", diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 9f5a51a71..05af2ba68 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -111,9 +111,10 @@ async def list_attacks_async( attack_types: Filter by attack type names (repeatable, OR-matched, case-sensitive). None or empty list applies no filter. converter_types: Filter by converter class names (case-insensitive). - ``None`` or empty list applies no filter. Combination semantics for - multiple entries are controlled by ``converter_types_match``. For the - "only attacks with no converters" case, use ``has_converters=False``. + ``None`` or an empty list applies no filter at this layer. Combination + semantics for multiple entries are controlled by ``converter_types_match``. + To restrict results to attacks with no converters, pass + ``has_converters=False`` instead. converter_types_match: How to combine multiple entries in ``converter_types``. ``"all"`` (default) matches attacks that used every listed converter. ``"any"`` matches attacks that used at least one of the listed converters. @@ -134,20 +135,28 @@ async def list_attacks_async( AttackListResponse with filtered and paginated attack summaries. """ # Phase 1: Query + lightweight filtering (no pieces needed) + # Coerce an empty converter_types list to None so it behaves as "no filter" at + # this layer — the "attacks with no converters" case is expressed through + # has_converters=False, which keeps the three layers (route/service/memory) + # consistent. + effective_converter_types = converter_types if converter_types else None + # For converter_types, push AND-logic down to the DB for "all" mode and for - # the degenerate cases (None, [], single entry). For "any" mode with 2+ + # the degenerate cases (None, single entry). For "any" mode with 2+ # entries, skip the DB filter and apply OR-matching in Python below. - push_converter_filter = converter_types is None or len(converter_types) <= 1 or converter_types_match == "all" + push_converter_filter = ( + effective_converter_types is None or len(effective_converter_types) == 1 or converter_types_match == "all" + ) attack_results = self._memory.get_attack_results( outcome=outcome, labels=labels if labels else None, attack_classes=attack_types if attack_types else None, - converter_classes=converter_types if push_converter_filter else None, + converter_classes=effective_converter_types if push_converter_filter else None, has_converters=has_converters, ) if not push_converter_filter: - requested = {c.lower() for c in converter_types or []} + requested = {c.lower() for c in effective_converter_types or []} attack_results = [ar for ar in attack_results if _converter_classes_on_attack(ar) & requested] filtered: list[AttackResult] = [] diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 19c18f51c..7bed595b9 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1569,9 +1569,12 @@ def get_attack_results( ) ) - if has_converters is not None: - # Reuse the array-empty match (array_to_match=[]) as the "no converters" condition; - # invert it for "has at least one converter". + # Skip when has_converters=True and converter_classes is already non-empty: + # the "ALL listed converters" constraint strictly implies "at least one + # converter", so adding this predicate is redundant work. + if has_converters is not None and not (has_converters is True and converter_classes): + # Reuse the array-empty match (array_to_match=[]) as the "no converters" + # condition; invert it for "has at least one converter". empty_condition = self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, property_path="$.children.attack_technique.children.attack.children.request_converters", @@ -1592,8 +1595,15 @@ def get_attack_results( ) if labels: - # Use database-specific JSON query method - conditions.append(self._get_attack_result_label_condition(labels=labels)) + # Strip keys whose value is an empty sequence — an empty sequence means + # "no OR-candidates", and per the docstring applies no filter for that + # key. Without this, the per-backend helpers would still emit a base + # EXISTS(... labels IS NOT NULL) predicate that is strictly more + # restrictive than "no filter". + effective_labels = {k: v for k, v in labels.items() if not (isinstance(v, (list, tuple)) and len(v) == 0)} + if effective_labels: + # Use database-specific JSON query method + conditions.append(self._get_attack_result_label_condition(labels=effective_labels)) if identifier_filters: conditions.extend( diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index e9f98cbb9..a08574196 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -22,6 +22,7 @@ ) from pyrit.backend.services.attack_service import ( AttackService, + _converter_classes_on_attack, get_attack_service, ) from pyrit.identifiers import ComponentIdentifier @@ -244,15 +245,20 @@ async def test_list_attacks_attack_types_empty_list_coerced_to_none(self, attack assert call_kwargs["attack_classes"] is None @pytest.mark.asyncio - async def test_list_attacks_filters_by_no_converters(self, attack_service, mock_memory) -> None: - """Test that converter_types=[] is forwarded to memory for DB-level filtering.""" + async def test_list_attacks_coerces_empty_converter_types_to_no_filter(self, attack_service, mock_memory) -> None: + """converter_types=[] at the service boundary means 'no converter filter'. + + The 'attacks with no converters' intent is expressed via has_converters=False; + an empty list is coerced to None so route/service/memory stay consistent. + """ mock_memory.get_attack_results.return_value = [] mock_memory.get_message_pieces.return_value = [] await attack_service.list_attacks_async(converter_types=[]) call_kwargs = mock_memory.get_attack_results.call_args[1] - assert call_kwargs["converter_classes"] == [] + assert call_kwargs["converter_classes"] is None + assert call_kwargs["has_converters"] is None @pytest.mark.asyncio async def test_list_attacks_forwards_has_converters_true(self, attack_service, mock_memory) -> None: @@ -2370,6 +2376,16 @@ def test_duplicate_conversation_remaps_assistant_to_simulated(self, attack_servi assert dup_piece._role == "simulated_assistant" + def test_converter_classes_on_attack_returns_empty_when_no_strategy_identifier(self): + """_converter_classes_on_attack returns an empty set when the attack has no strategy identifier.""" + ar = AttackResult( + conversation_id="conv-1", + objective="obj", + atomic_attack_identifier=None, + outcome=AttackOutcome.UNDETERMINED, + ) + assert _converter_classes_on_attack(ar) == set() + class TestAddMessageGuards: """Tests for target-mismatch and operator-mismatch guards in add_message_async.""" diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index d812e222a..2e67b50e9 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -807,6 +807,37 @@ def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): assert conversation_ids == {"conv_1", "conv_3"} +def test_get_attack_results_by_labels_empty_sequence_value_skips_key(sqlite_instance: MemoryInterface): + """An empty sequence for a label key is skipped (no filter applied for that key). + + Includes an attack whose prompt has ``labels=None`` and another with non-matching + labels (no ``operator`` key at all) to guard against a regression where the filter + silently adds an "EXISTS(... labels IS NOT NULL)" constraint when all values for a + key are empty. + """ + mp1 = create_message_piece("conv_1", 1, labels={"operator": "roakey"}) + mp2 = create_message_piece("conv_2", 2, labels={"operator": "alice"}) + mp3 = create_message_piece("conv_3", 3, labels={"phase": "initial"}) + mp4 = create_message_piece("conv_4", 4, labels=None) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[mp1, mp2, mp3, mp4]) + + ar1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) + ar2 = create_attack_result("conv_2", 2, AttackOutcome.SUCCESS) + ar3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) + ar4 = create_attack_result("conv_4", 4, AttackOutcome.SUCCESS) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) + + # Empty sequence for "operator" → filter is ignored entirely, all attacks match + # (including conv_3 which has no "operator" key and conv_4 which has no labels). + results = sqlite_instance.get_attack_results(labels={"operator": []}) + assert {r.conversation_id for r in results} == {"conv_1", "conv_2", "conv_3", "conv_4"} + + # Mixed: one key with empty-sequence (ignored) + one real filter should behave + # exactly like the real filter alone. + mixed = sqlite_instance.get_attack_results(labels={"operator": [], "phase": "initial"}) + assert {r.conversation_id for r in mixed} == {"conv_3"} + + def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface): """Test filtering attack results by multiple labels (AND logic).""" diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 4d9e056c5..91ed52a06 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -354,6 +354,38 @@ def test_get_condition_json_property_match_bind_params( assert list(mv_params.values())[0] == expected_value +def test_get_attack_result_label_condition_with_string_value(memory_interface: AzureSQLMemory): + """String values produce a single-placeholder IN clause with the stringified value.""" + condition = memory_interface._get_attack_result_label_condition(labels={"operator": "roakey"}) + params = condition.compile().params + assert params.get("label_operator_0") == "roakey" + + +def test_get_attack_result_label_condition_with_sequence_value(memory_interface: AzureSQLMemory): + """Sequence values produce one placeholder per element.""" + condition = memory_interface._get_attack_result_label_condition(labels={"operation": ["op_a", "op_b", "op_c"]}) + params = condition.compile().params + assert params.get("label_operation_0") == "op_a" + assert params.get("label_operation_1") == "op_b" + assert params.get("label_operation_2") == "op_c" + + +def test_get_attack_result_label_condition_skips_empty_sequence(memory_interface: AzureSQLMemory): + """Empty sequence values are skipped (no filter applied for that key).""" + condition = memory_interface._get_attack_result_label_condition(labels={"operator": "roakey", "operation": []}) + params = condition.compile().params + # operator gets a bind param; operation (empty) does not. + assert params.get("label_operator_0") == "roakey" + assert not any(k.startswith("label_operation_") for k in params) + + +def test_get_attack_result_label_condition_empty_labels_dict(memory_interface: AzureSQLMemory): + """An empty labels dict produces a condition with no label filters bound.""" + condition = memory_interface._get_attack_result_label_condition(labels={}) + params = condition.compile().params + assert not any(k.startswith("label_") for k in params) + + @pytest.mark.parametrize( "case_sensitive, partial_match, expected_sql_fragment", [ diff --git a/uv.lock b/uv.lock index 21952cbdc..e4c484965 100644 --- a/uv.lock +++ b/uv.lock @@ -4895,7 +4895,7 @@ wheels = [ [[package]] name = "pyrit" -version = "0.13.0.dev0" +version = "0.14.0.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From deeb662d1422ae3b1e291c6f4de4532ddfcab692 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Wed, 22 Apr 2026 19:16:24 -0400 Subject: [PATCH 7/8] e2e: fix multiselect Combobox role + attack_types mock param - Fluent v9 multiselect Combobox renders items as role=menuitemcheckbox, not role=option. Update attack-class and operator filter tests to query the correct role. - Update the mockHistoryAPIs handler to read the renamed attack_types (plural, repeatable) query param and to group repeated label keys into OR-sets (matches the route's real semantic). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- frontend/e2e/history.spec.ts | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/frontend/e2e/history.spec.ts b/frontend/e2e/history.spec.ts index e8ef93b75..bc25c0880 100644 --- a/frontend/e2e/history.spec.ts +++ b/frontend/e2e/history.spec.ts @@ -161,23 +161,29 @@ async function mockHistoryAPIs( return; } const url = new URL(route.request().url()); - const attackType = url.searchParams.get("attack_type"); + const attackTypeParams = url.searchParams.getAll("attack_types"); const outcome = url.searchParams.get("outcome"); const labelParams = url.searchParams.getAll("label"); let filtered = [...attacks]; - if (attackType) { - filtered = filtered.filter((a) => a.attack_type === attackType); + if (attackTypeParams.length > 0) { + filtered = filtered.filter((a) => attackTypeParams.includes(a.attack_type)); } if (outcome) { filtered = filtered.filter((a) => a.outcome === outcome); } if (labelParams.length > 0) { + // Group repeated label keys into OR-sets; combine across keys with AND. + const grouped = new Map(); + for (const lp of labelParams) { + const [key, val] = lp.split(":"); + if (!key) continue; + const bucket = grouped.get(key) ?? []; + bucket.push(val ?? ""); + grouped.set(key, bucket); + } filtered = filtered.filter((a) => - labelParams.every((lp) => { - const [key, val] = lp.split(":"); - return a.labels[key] === val; - }), + Array.from(grouped.entries()).every(([key, vals]) => vals.includes(a.labels[key] ?? "")), ); } @@ -210,10 +216,11 @@ test.describe("Attack History Filters", () => { await expect(page.getByTestId("attack-row-atk-alice-a")).toBeVisible(); await expect(page.getByTestId("attack-row-atk-bob-b")).toBeVisible(); - // Open the attack class combobox (multiselect: must click the input, not the wrapper) + // Open the attack class combobox. Multiselect Combobox renders items as + // role="menuitemcheckbox", not role="option". const dropdown = page.getByTestId("attack-class-filter"); - await dropdown.locator("input").click(); - await page.getByRole("option", { name: "SingleTurnAttack" }).click(); + await dropdown.click(); + await page.getByRole("menuitemcheckbox", { name: "SingleTurnAttack" }).click(); // Only SingleTurnAttack attacks should be visible await expect(page.getByTestId("attack-row-atk-alice-a")).toBeVisible({ timeout: 5_000 }); @@ -243,10 +250,11 @@ test.describe("Attack History Filters", () => { await mockHistoryAPIs(page); await goToHistory(page); - // Open the operator combobox (multiselect: must click the input, not the wrapper) + // Open the operator combobox. Multiselect Combobox renders items as + // role="menuitemcheckbox", not role="option". const operatorDropdown = page.getByTestId("operator-filter"); - await operatorDropdown.locator("input").click(); - await page.getByRole("option", { name: "bob" }).click(); + await operatorDropdown.click(); + await page.getByRole("menuitemcheckbox", { name: "bob" }).click(); // Only bob's attacks await expect(page.getByTestId("attack-row-atk-bob-b")).toBeVisible({ timeout: 5_000 }); From 8161d022143619b59a14ff1aea072dc06e8ec382 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Thu, 23 Apr 2026 10:57:42 -0400 Subject: [PATCH 8/8] Address Copilot review findings on PR #1643 - Validate label keys against [A-Za-z0-9_.-]+ allowlist in MemoryInterface.get_attack_results before interpolation into backend JSON-path expressions, preventing SQL injection via crafted label keys (e.g. Azure SQL's text(f\$.{key}) fragment). - Make multi-select Attack History filters actually searchable: extract a SearchableMultiCombobox helper that tracks open + search state, filters options by typed text, and shows the formatted 'name (+N)' summary only while the popover is closed. Apply the same pattern inline to the converter filter (keeps its sentinel OptionGroup). - Strip empty strings from ?attack_types= query param for symmetry with ?converter_types=, and clarify the route-layer comment now that the service coerces [] to None. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../History/HistoryFiltersBar.test.tsx | 30 ++++ .../components/History/HistoryFiltersBar.tsx | 133 +++++++++++++----- pyrit/backend/routes/attacks.py | 7 +- pyrit/memory/memory_interface.py | 17 +++ .../test_interface_attack_results.py | 18 +++ 5 files changed, 167 insertions(+), 38 deletions(-) diff --git a/frontend/src/components/History/HistoryFiltersBar.test.tsx b/frontend/src/components/History/HistoryFiltersBar.test.tsx index b2e15bdd0..9503fb9f3 100644 --- a/frontend/src/components/History/HistoryFiltersBar.test.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.test.tsx @@ -403,6 +403,36 @@ describe('HistoryFiltersBar', () => { expect(input.value).toBe('alice') }) + it('should filter multi-select Combobox options by typed search text and reset search on selection', async () => { + const onFiltersChange = jest.fn() + const props = { + ...defaultProps, + onFiltersChange, + operatorOptions: ['alice', 'bob', 'carol'], + } + + render( + + + + ) + + const operatorInput = screen.getByTestId('operator-filter') as HTMLInputElement + // Open and type a search; only matching options should remain visible. + fireEvent.click(operatorInput) + fireEvent.change(operatorInput, { target: { value: 'car' } }) + + expect(operatorInput.value).toBe('car') + expect(screen.queryByText('alice')).not.toBeInTheDocument() + expect(screen.queryByText('bob')).not.toBeInTheDocument() + const match = await screen.findByText('carol') + fireEvent.click(match) + + expect(onFiltersChange).toHaveBeenCalledWith( + expect.objectContaining({ operator: ['carol'] }) + ) + }) + it('should keep the outcome Dropdown display label after a sibling filter changes', () => { const props = { ...defaultProps, diff --git a/frontend/src/components/History/HistoryFiltersBar.tsx b/frontend/src/components/History/HistoryFiltersBar.tsx index 1472c7a6a..049056fad 100644 --- a/frontend/src/components/History/HistoryFiltersBar.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.tsx @@ -1,3 +1,4 @@ +import { useState } from 'react' import { Button, Tooltip, @@ -23,13 +24,68 @@ const OUTCOME_LABELS: Record = { } // Fluent's multiselect Combobox doesn't auto-populate its input from selectedOptions; -// we have to drive the displayed text via `value` ourselves. +// we have to drive the displayed text via `value` ourselves. This format is only +// shown when the popover is closed — while it's open we show the user's typed +// search text so typing-to-search actually works. function formatMultiSelectValue(selected: string[]): string { if (selected.length === 0) return '' if (selected.length === 1) return selected[0] return `${selected[0]} (+${selected.length - 1})` } +interface SearchableMultiComboboxProps { + placeholder: string + selectedOptions: string[] + options: string[] + onSelect: (selected: string[]) => void + testid: string + className?: string +} + +function SearchableMultiCombobox({ + placeholder, + selectedOptions, + options, + onSelect, + testid, + className, +}: SearchableMultiComboboxProps) { + const [open, setOpen] = useState(false) + const [search, setSearch] = useState('') + + const filtered = search + ? options.filter((o) => o.toLowerCase().includes(search.toLowerCase())) + : options + + return ( + { + setOpen(data.open) + // Clear search text both when opening (start fresh) and when closing + // (so the formatted value is shown again). + setSearch('') + }} + selectedOptions={selectedOptions} + value={open ? search : formatMultiSelectValue(selectedOptions)} + onChange={(e) => setSearch((e.target as HTMLInputElement).value)} + onOptionSelect={(_e, data) => { + onSelect(data.selectedOptions) + setSearch('') + }} + data-testid={testid} + > + {filtered.map((o) => ( + + ))} + + ) +} + interface HistoryFiltersBarProps { filters: HistoryFilters onFiltersChange: (filters: HistoryFilters) => void @@ -98,6 +154,14 @@ export default function HistoryFiltersBar({ const showMatchModeToggle = converterFilter.length >= 2 && hasConverters !== false + // Searchable state for the converter Combobox (custom because of the + // "(No converters)" sentinel OptionGroup). + const [converterOpen, setConverterOpen] = useState(false) + const [converterSearch, setConverterSearch] = useState('') + const converterMatchesSearch = (c: string) => + !converterSearch || c.toLowerCase().includes(converterSearch.toLowerCase()) + const filteredConverterOptions = converterOptions.filter(converterMatchesSearch) + return (
@@ -114,19 +178,14 @@ export default function HistoryFiltersBar({ )} - setFilter('attackClasses', data.selectedOptions)} - data-testid="attack-class-filter" - > - {attackClassOptions.map((cls) => ( - - ))} - + options={attackClassOptions} + onSelect={(selected) => setFilter('attackClasses', selected)} + testid="attack-class-filter" + /> { + setConverterOpen(data.open) + setConverterSearch('') + }} selectedOptions={converterSelectedOptions} value={ - hasConverters === false - ? '(No converters)' - : formatMultiSelectValue(converterFilter) + converterOpen + ? converterSearch + : hasConverters === false + ? '(No converters)' + : formatMultiSelectValue(converterFilter) } - onOptionSelect={(_e, data) => handleConverterSelect(data.selectedOptions)} + onChange={(e) => setConverterSearch((e.target as HTMLInputElement).value)} + onOptionSelect={(_e, data) => { + handleConverterSelect(data.selectedOptions) + setConverterSearch('') + }} data-testid="converter-filter" > - {converterOptions.map((c) => ( + {filteredConverterOptions.map((c) => ( ))} @@ -183,32 +254,22 @@ export default function HistoryFiltersBar({ /> )} - setFilter('operator', data.selectedOptions)} - data-testid="operator-filter" - > - {operatorOptions.map((o) => ( - - ))} - - setFilter('operator', selected)} + testid="operator-filter" + /> + setFilter('operation', data.selectedOptions)} - data-testid="operation-filter" - > - {operationOptions.map((o) => ( - - ))} - + options={operationOptions} + onSelect={(selected) => setFilter('operation', selected)} + testid="operation-filter" + />