Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions pyrit/prompt_normalizer/prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,12 @@ async def send_prompt_async(
attack_identifier (Optional[ComponentIdentifier], optional): Identifier for the attack. Defaults to
None.

Returns:
Message: The response received from the target.

Raises:
Exception: If an error occurs during the request processing.
ValueError: If the message pieces are not part of the same sequence.
EmptyResponseException: If the target returns no valid responses.

Returns:
Message: The response received from the target.
"""
# Validates that the MessagePieces in the Message are part of the same sequence
request_converter_configurations = request_converter_configurations or []
Expand Down Expand Up @@ -156,7 +155,15 @@ async def send_prompt_async(
await self._calc_hash(request=request)
self.memory.add_message_to_memory(request=request)
return request
raise EmptyResponseException(message="Target returned no valid responses")
empty_response = construct_response_from_request(
request=request.message_pieces[0],
response_text_pieces=[""],
response_type="text",
error="empty",
)
await self._calc_hash(request=empty_response)
self.memory.add_message_to_memory(request=empty_response)
return empty_response

# Process all response messages (targets return list[Message])
# Only apply response converters to the last message (final response)
Expand Down Expand Up @@ -210,7 +217,7 @@ async def send_prompt_batch_to_target_async(
"conversation_id",
]

responses = await batch_task_async(
return await batch_task_async(
prompt_target=target,
batch_size=batch_size,
items_to_batch=batch_items,
Expand All @@ -221,9 +228,6 @@ async def send_prompt_batch_to_target_async(
attack_identifier=attack_identifier,
)

# Filter out None responses (e.g., from empty responses)
return [response for response in responses if response is not None]

async def convert_values(
self,
converter_configurations: list[PromptConverterConfiguration],
Expand Down
101 changes: 74 additions & 27 deletions tests/unit/prompt_normalizer/test_prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,23 @@ async def test_send_prompt_async_multiple_converters(mock_memory_instance, seed_


@pytest.mark.asyncio
async def test_send_prompt_async_no_response_raises_empty_response(mock_memory_instance, seed_group):
prompt_target = AsyncMock()
async def test_send_prompt_async_no_response_adds_memory(mock_memory_instance, seed_group):
prompt_target = MagicMock()
prompt_target.send_prompt_async = AsyncMock(return_value=None)
prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget")

normalizer = PromptNormalizer()
message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user")

with pytest.raises(EmptyResponseException):
await normalizer.send_prompt_async(message=message, target=prompt_target)

# Request should still be added to memory before the exception
assert mock_memory_instance.add_message_to_memory.call_count == 1
response = await normalizer.send_prompt_async(message=message, target=prompt_target)
assert mock_memory_instance.add_message_to_memory.call_count == 2

request = mock_memory_instance.add_message_to_memory.call_args[1]["request"]
assert_message_piece_hashes_set(request)
assert response.message_pieces[0].response_error == "empty"
assert response.message_pieces[0].original_value == ""
assert response.message_pieces[0].original_value_data_type == "text"
assert_message_piece_hashes_set(response)


@pytest.mark.asyncio
Expand Down Expand Up @@ -187,34 +189,29 @@ async def test_send_prompt_async_request_response_added_to_memory(mock_memory_in

@pytest.mark.asyncio
async def test_send_prompt_async_exception(mock_memory_instance, seed_group):
prompt_target = AsyncMock()
prompt_target = MagicMock()
prompt_target.send_prompt_async = AsyncMock(side_effect=ValueError("test_exception"))
prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget")

seed_prompt_value = seed_group.prompts[0].value

normalizer = PromptNormalizer()
message = Message.from_prompt(prompt=seed_prompt_value, role="user")

with patch("pyrit.models.construct_response_from_request") as mock_construct:
mock_construct.return_value = "test"
with pytest.raises(Exception, match="Error sending prompt with conversation ID"):
await normalizer.send_prompt_async(message=message, target=prompt_target)

try:
await normalizer.send_prompt_async(message=message, target=prompt_target)
except ValueError:
assert mock_memory_instance.add_message_to_memory.call_count == 2
assert mock_memory_instance.add_message_to_memory.call_count == 2

# Validate that first request is added to memory, then exception is added to memory
assert (
seed_prompt_value
== mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"]
.message_pieces[0]
.original_value
)
assert (
mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"]
.message_pieces[0]
.original_value
== "test_exception"
)
# Validate that first request is added to memory, then exception is added to memory
assert (
seed_prompt_value
== mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"].message_pieces[0].original_value
)
assert (
"test_exception"
in mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"].message_pieces[0].original_value
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -386,6 +383,56 @@ async def test_prompt_normalizer_send_prompt_batch_async_throws(
assert len(results) == 1


@pytest.mark.asyncio
async def test_prompt_normalizer_send_prompt_batch_async_preserves_empty_response_alignment(
    mock_memory_instance,
):
    """Batch results stay positionally aligned with their requests even when one
    target call yields no response (it is surfaced as an 'empty'-error message
    rather than being dropped)."""
    target = MagicMock()
    target._max_requests_per_minute = None
    target.get_identifier.return_value = get_mock_target_identifier("MockTarget")
    # First request gets a real reply; second gets None (no response).
    first_reply = MessagePiece(
        role="assistant", original_value="response 1", conversation_id="conv-1"
    ).to_message()
    target.send_prompt_async = AsyncMock(side_effect=[[first_reply], None])

    normalizer = PromptNormalizer()
    batch = [
        NormalizerRequest(
            message=Message.from_prompt(prompt="prompt 1", role="user"),
            conversation_id="conv-1",
        ),
        NormalizerRequest(
            message=Message.from_prompt(prompt="prompt 2", role="user"),
            conversation_id="conv-2",
        ),
    ]

    results = await normalizer.send_prompt_batch_to_target_async(
        requests=batch, target=target, batch_size=2
    )

    assert len(results) == 2
    first, second = results
    assert first.message_pieces[0].original_value == "response 1"
    # The missing response was converted to an empty-error placeholder
    # that keeps the second request's conversation id.
    assert second.message_pieces[0].response_error == "empty"
    assert second.message_pieces[0].original_value == ""
    assert second.message_pieces[0].conversation_id == "conv-2"


@pytest.mark.asyncio
async def test_send_prompt_async_none_in_list_response_returns_empty(mock_memory_instance, seed_group):
    """A target replying with [None] (a list containing None) is treated like no
    response at all: the normalizer returns an 'empty'-error text message."""
    target = MagicMock()
    target.get_identifier.return_value = get_mock_target_identifier("MockTarget")
    target.send_prompt_async = AsyncMock(return_value=[None])

    request = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user")
    result = await PromptNormalizer().send_prompt_async(message=request, target=target)

    piece = result.message_pieces[0]
    assert piece.response_error == "empty"
    assert piece.original_value == ""


@pytest.mark.asyncio
async def test_build_message(mock_memory_instance, seed_group):
# This test is obsolete since _build_message was removed and message preparation
Expand Down
Loading