Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR
ar = results[0]
main_conversation_id = ar.conversation_id

self._validate_target_match(attack_identifier=ar.attack_identifier, request=request)
self._validate_target_match(attack_identifier=ar.get_attack_strategy_identifier(), request=request)
self._validate_operator_match(conversation_id=main_conversation_id, request=request)

msg_conversation_id = request.target_conversation_id
Expand Down Expand Up @@ -719,7 +719,7 @@ async def _update_attack_after_message_async(
if request.converter_ids:
converter_objs = get_converter_service().get_converter_objects_for_ids(converter_ids=request.converter_ids)
new_converter_ids = [c.get_identifier() for c in converter_objs]
aid = ar.attack_identifier
aid = ar.get_attack_strategy_identifier()
if aid:
existing_converters: list[ComponentIdentifier] = list(aid.get_child_list("request_converters"))
existing_hashes = {c.hash for c in existing_converters}
Expand All @@ -733,8 +733,6 @@ async def _update_attack_after_message_async(
params=dict(aid.params),
children=new_children,
)
update_fields["attack_identifier"] = new_aid.to_dict()
# Also update atomic_attack_identifier so get_attack_strategy_identifier() sees the change
if ar.atomic_attack_identifier:
atomic = ComponentIdentifier.from_dict(ar.atomic_attack_identifier.to_dict())
atomic_children = dict(atomic.children)
Expand Down
92 changes: 88 additions & 4 deletions tests/unit/backend/test_attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ async def test_create_attack_default_name(self, attack_service, mock_memory) ->
call_args = mock_memory.add_attack_results_to_memory.call_args
stored_ar = call_args[1]["attack_results"][0]
assert stored_ar.objective == "Manual attack via GUI"
assert stored_ar.attack_identifier.class_name == "ManualAttack"
assert stored_ar.get_attack_strategy_identifier().class_name == "ManualAttack"


# ============================================================================
Expand Down Expand Up @@ -1158,9 +1158,9 @@ async def test_converter_ids_propagate_even_when_preconverted(self, attack_servi
# Normalizer should still get empty converter configs since pieces are preconverted
call_kwargs = mock_normalizer.send_prompt_async.call_args[1]
assert call_kwargs["request_converter_configurations"] == []
# attack_identifier should be updated with converter identifiers
# atomic_attack_identifier should be updated with converter identifiers
update_call = mock_memory.update_attack_result_by_id.call_args[1]
assert "attack_identifier" in update_call["update_fields"]
assert "atomic_attack_identifier" in update_call["update_fields"]

@pytest.mark.asyncio
async def test_add_message_no_existing_pieces_uses_request_labels(self, attack_service, mock_memory) -> None:
Expand Down Expand Up @@ -2199,11 +2199,95 @@ async def test_add_message_merges_converter_identifiers_without_duplicates(self,
await attack_service.add_message_async(attack_result_id="attack-1", request=request)

update_fields = mock_memory.update_attack_result_by_id.call_args[1]["update_fields"]
persisted_identifiers = update_fields["attack_identifier"]["children"]["request_converters"]
# Converters are now stored inside atomic_attack_identifier -> attack_technique -> attack
atomic_id = update_fields["atomic_attack_identifier"]
attack_id = atomic_id["children"]["attack_technique"]["children"]["attack"]
persisted_identifiers = attack_id["children"]["request_converters"]
persisted_classes = [identifier["class_name"] for identifier in persisted_identifiers]

assert persisted_classes.count("ExistingConverter") == 1
assert persisted_classes.count("NewConverter") == 1
# The deprecated attack_identifier column should NOT be written
assert "attack_identifier" not in update_fields

@pytest.mark.asyncio
async def test_converter_merge_with_flat_atomic_identifier(self, attack_service, mock_memory):
"""Should merge converters via fallback path when atomic_attack_identifier has no attack_technique child."""
from pyrit.backend.models.attacks import AttackSummary, ConversationMessagesResponse

new_converter = ComponentIdentifier(
class_name="NewConverter",
class_module="pyrit.prompt_converter",
params={"supported_input_types": ("text",), "supported_output_types": ("text",)},
)

# Build a flat atomic identifier (no attack_technique nesting — legacy shape)
attack_id = ComponentIdentifier(
class_name="ManualAttack",
class_module="pyrit.backend",
children={
"objective_target": ComponentIdentifier(class_name="TextTarget", class_module="pyrit.prompt_target"),
},
)
ar = make_attack_result(conversation_id="flat-1")
ar.atomic_attack_identifier = ComponentIdentifier(
class_name="AtomicAttack",
class_module="pyrit.scenario.core.atomic_attack",
children={"attack": attack_id},
)

mock_memory.get_attack_results.return_value = [ar]
mock_memory.get_message_pieces.return_value = []

request = AddMessageRequest(
role="user",
pieces=[MessagePieceRequest(original_value="Hello")],
target_conversation_id="flat-1",
send=False,
converter_ids=["c-1"],
)

with (
patch("pyrit.backend.services.attack_service.get_converter_service") as mock_get_converter_service,
patch.object(
attack_service,
"get_attack_async",
new=AsyncMock(
return_value=AttackSummary(
attack_result_id="ar-flat-1",
conversation_id="flat-1",
attack_type="ManualAttack",
converters=[],
message_count=0,
labels={},
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
),
),
patch.object(
attack_service,
"get_conversation_messages_async",
new=AsyncMock(return_value=ConversationMessagesResponse(conversation_id="flat-1", messages=[])),
),
):
mock_converter_service = MagicMock()
mock_converter_service.get_converter_objects_for_ids.return_value = [
MagicMock(get_identifier=MagicMock(return_value=new_converter)),
]
mock_get_converter_service.return_value = mock_converter_service

await attack_service.add_message_async(attack_result_id="flat-1", request=request)

update_fields = mock_memory.update_attack_result_by_id.call_args[1]["update_fields"]
assert "atomic_attack_identifier" in update_fields
assert "attack_identifier" not in update_fields
# Flat fallback: converter should be under atomic -> attack -> children
atomic_id = update_fields["atomic_attack_identifier"]
attack_child = atomic_id["children"]["attack"]
persisted_converters = attack_child["children"]["request_converters"]
assert len(persisted_converters) == 1
assert persisted_converters[0]["class_name"] == "NewConverter"

def test_duplicate_conversation_up_to_adds_pieces_when_present(self, attack_service, mock_memory):
"""Should duplicate up to cutoff and persist duplicated pieces only when returned."""
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/memory/test_memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,24 @@ def test_filter_json_serializable_metadata_mixed(self):
assert "int_val" in result
assert "non_serializable" not in result

def test_get_attack_result_prefers_atomic_over_stale_attack_identifier(self):
"""When atomic_attack_identifier and attack_identifier disagree, atomic wins."""
from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier

correct_attack_id = ComponentIdentifier(class_name="CorrectAttack", class_module="pyrit.backend")
atomic_id = build_atomic_attack_identifier(attack_identifier=correct_attack_id)
ar = _make_attack_result(atomic_attack_identifier=atomic_id)
entry = AttackResultEntry(entry=ar)

# Simulate a stale attack_identifier column (as if it wasn't updated)
stale_id = ComponentIdentifier(class_name="StaleAttack", class_module="pyrit.backend")
entry.attack_identifier = stale_id.to_dict()

round_tripped = entry.get_attack_result()
strategy = round_tripped.get_attack_strategy_identifier()
assert strategy is not None
assert strategy.class_name == "CorrectAttack"


# ---------------------------------------------------------------------------
# ScenarioResultEntry
Expand Down
Loading