From 081ef6188e2fc1a59a4575d8e8b2f692aeb9ead2 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 22 Apr 2026 15:46:31 -0700 Subject: [PATCH] Stop using deprecated AttackResult.attack_identifier in attack_service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace ar.attack_identifier with ar.get_attack_strategy_identifier() in attack_service.py. Remove writing the deprecated attack_identifier DB column in _update_attack_after_message_async — only atomic_attack_identifier (the source of truth) is now updated. Add tests for: - Flat/legacy atomic identifier fallback converter merge path - Deprecated attack_identifier column not written to update_fields - DB read prefers atomic_attack_identifier over stale attack_identifier Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/attack_service.py | 6 +- tests/unit/backend/test_attack_service.py | 92 ++++++++++++++++++++++- tests/unit/memory/test_memory_models.py | 18 +++++ 3 files changed, 108 insertions(+), 8 deletions(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 5c8ea20b6..b5e8c2b34 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -558,7 +558,7 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR ar = results[0] main_conversation_id = ar.conversation_id - self._validate_target_match(attack_identifier=ar.attack_identifier, request=request) + self._validate_target_match(attack_identifier=ar.get_attack_strategy_identifier(), request=request) self._validate_operator_match(conversation_id=main_conversation_id, request=request) msg_conversation_id = request.target_conversation_id @@ -719,7 +719,7 @@ async def _update_attack_after_message_async( if request.converter_ids: converter_objs = get_converter_service().get_converter_objects_for_ids(converter_ids=request.converter_ids) new_converter_ids = [c.get_identifier() for c in converter_objs] - aid = ar.attack_identifier + aid = ar.get_attack_strategy_identifier() if aid: existing_converters: list[ComponentIdentifier] = list(aid.get_child_list("request_converters")) existing_hashes = {c.hash for c in existing_converters} @@ -733,8 +733,6 @@ async def _update_attack_after_message_async( params=dict(aid.params), children=new_children, ) - update_fields["attack_identifier"] = new_aid.to_dict() - # Also update atomic_attack_identifier so get_attack_strategy_identifier() sees the change if ar.atomic_attack_identifier: atomic = ComponentIdentifier.from_dict(ar.atomic_attack_identifier.to_dict()) atomic_children = dict(atomic.children) diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 3d09fe6b7..d0e02b5fd 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -748,7 +748,7 @@ async def test_create_attack_default_name(self, attack_service, mock_memory) -> call_args = mock_memory.add_attack_results_to_memory.call_args stored_ar = call_args[1]["attack_results"][0] assert stored_ar.objective == "Manual attack via GUI" - assert stored_ar.attack_identifier.class_name == "ManualAttack" + assert stored_ar.get_attack_strategy_identifier().class_name == "ManualAttack" # ============================================================================ @@ -1158,9 +1158,9 @@ async def test_converter_ids_propagate_even_when_preconverted(self, attack_servi # Normalizer should still get empty converter configs since pieces are preconverted call_kwargs = mock_normalizer.send_prompt_async.call_args[1] assert call_kwargs["request_converter_configurations"] == [] - # attack_identifier should be updated with converter identifiers + # atomic_attack_identifier should be updated with converter identifiers update_call = mock_memory.update_attack_result_by_id.call_args[1] - assert "attack_identifier" in update_call["update_fields"] + assert "atomic_attack_identifier" in update_call["update_fields"] @pytest.mark.asyncio async def test_add_message_no_existing_pieces_uses_request_labels(self, attack_service, mock_memory) -> None: @@ -2199,11 +2199,95 @@ async def test_add_message_merges_converter_identifiers_without_duplicates(self, await attack_service.add_message_async(attack_result_id="attack-1", request=request) update_fields = mock_memory.update_attack_result_by_id.call_args[1]["update_fields"] - persisted_identifiers = update_fields["attack_identifier"]["children"]["request_converters"] + # Converters are now stored inside atomic_attack_identifier -> attack_technique -> attack + atomic_id = update_fields["atomic_attack_identifier"] + attack_id = atomic_id["children"]["attack_technique"]["children"]["attack"] + persisted_identifiers = attack_id["children"]["request_converters"] persisted_classes = [identifier["class_name"] for identifier in persisted_identifiers] assert persisted_classes.count("ExistingConverter") == 1 assert persisted_classes.count("NewConverter") == 1 + # The deprecated attack_identifier column should NOT be written + assert "attack_identifier" not in update_fields + + @pytest.mark.asyncio + async def test_converter_merge_with_flat_atomic_identifier(self, attack_service, mock_memory): + """Should merge converters via fallback path when atomic_attack_identifier has no attack_technique child.""" + from pyrit.backend.models.attacks import AttackSummary, ConversationMessagesResponse + + new_converter = ComponentIdentifier( + class_name="NewConverter", + class_module="pyrit.prompt_converter", + params={"supported_input_types": ("text",), "supported_output_types": ("text",)}, + ) + + # Build a flat atomic identifier (no attack_technique nesting — legacy shape) + attack_id = ComponentIdentifier( + class_name="ManualAttack", + class_module="pyrit.backend", + children={ + "objective_target": ComponentIdentifier(class_name="TextTarget", class_module="pyrit.prompt_target"), + }, + ) + ar = make_attack_result(conversation_id="flat-1") + ar.atomic_attack_identifier = ComponentIdentifier( + class_name="AtomicAttack", + class_module="pyrit.scenario.core.atomic_attack", + children={"attack": attack_id}, + ) + + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + + request = AddMessageRequest( + role="user", + pieces=[MessagePieceRequest(original_value="Hello")], + target_conversation_id="flat-1", + send=False, + converter_ids=["c-1"], + ) + + with ( + patch("pyrit.backend.services.attack_service.get_converter_service") as mock_get_converter_service, + patch.object( + attack_service, + "get_attack_async", + new=AsyncMock( + return_value=AttackSummary( + attack_result_id="ar-flat-1", + conversation_id="flat-1", + attack_type="ManualAttack", + converters=[], + message_count=0, + labels={}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + ), + ), + patch.object( + attack_service, + "get_conversation_messages_async", + new=AsyncMock(return_value=ConversationMessagesResponse(conversation_id="flat-1", messages=[])), + ), + ): + mock_converter_service = MagicMock() + mock_converter_service.get_converter_objects_for_ids.return_value = [ + MagicMock(get_identifier=MagicMock(return_value=new_converter)), + ] + mock_get_converter_service.return_value = mock_converter_service + + await attack_service.add_message_async(attack_result_id="flat-1", request=request) + + update_fields = mock_memory.update_attack_result_by_id.call_args[1]["update_fields"] + assert "atomic_attack_identifier" in update_fields + assert "attack_identifier" not in update_fields + # Flat fallback: converter should be under atomic -> attack -> children + atomic_id = update_fields["atomic_attack_identifier"] + attack_child = atomic_id["children"]["attack"] + persisted_converters = attack_child["children"]["request_converters"] + assert len(persisted_converters) == 1 + assert persisted_converters[0]["class_name"] == "NewConverter" def test_duplicate_conversation_up_to_adds_pieces_when_present(self, attack_service, mock_memory): """Should duplicate up to cutoff and persist duplicated pieces only when returned.""" diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index b74d9c967..0cdb10b6d 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -389,6 +389,24 @@ def test_filter_json_serializable_metadata_mixed(self): assert "int_val" in result assert "non_serializable" not in result + def test_get_attack_result_prefers_atomic_over_stale_attack_identifier(self): + """When atomic_attack_identifier and attack_identifier disagree, atomic wins.""" + from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier + + correct_attack_id = ComponentIdentifier(class_name="CorrectAttack", class_module="pyrit.backend") + atomic_id = build_atomic_attack_identifier(attack_identifier=correct_attack_id) + ar = _make_attack_result(atomic_attack_identifier=atomic_id) + entry = AttackResultEntry(entry=ar) + + # Simulate a stale attack_identifier column (as if it wasn't updated) + stale_id = ComponentIdentifier(class_name="StaleAttack", class_module="pyrit.backend") + entry.attack_identifier = stale_id.to_dict() + + round_tripped = entry.get_attack_result() + strategy = round_tripped.get_attack_strategy_identifier() + assert strategy is not None + assert strategy.class_name == "CorrectAttack" + # --------------------------------------------------------------------------- # ScenarioResultEntry