From 0e4b1047a1e179327246ee5d84288e0d61c24aae Mon Sep 17 00:00:00 2001 From: biefan Date: Tue, 17 Mar 2026 15:40:22 +0800 Subject: [PATCH] Preserve empty responses in prompt normalizer batches - Return explicit empty response Message (error='empty') from send_prompt_async() instead of None when a target returns no messages - Preserve write-only target (empty list) path returning request as-is - Remove None filter in send_prompt_batch_to_target_async() since send_prompt_async() now always returns a valid Message - Add regression tests for batch response alignment and [None] edge case - Fix docstring section ordering (Returns before Raises) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/prompt_normalizer/prompt_normalizer.py | 22 ++-- .../test_prompt_normalizer.py | 101 +++++++++++++----- 2 files changed, 87 insertions(+), 36 deletions(-) diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 7407cd849..ce5452eda 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -83,13 +83,12 @@ async def send_prompt_async( attack_identifier (Optional[ComponentIdentifier], optional): Identifier for the attack. Defaults to None. + Returns: + Message: The response received from the target. + Raises: Exception: If an error occurs during the request processing. ValueError: If the message pieces are not part of the same sequence. - EmptyResponseException: If the target returns no valid responses. - - Returns: - Message: The response received from the target. """ # Validates that the MessagePieces in the Message are part of the same sequence request_converter_configurations = request_converter_configurations or [] @@ -156,7 +155,15 @@ async def send_prompt_async( await self._calc_hash(request=request) self.memory.add_message_to_memory(request=request) return request - raise EmptyResponseException(message="Target returned no valid responses") + empty_response = construct_response_from_request( + request=request.message_pieces[0], + response_text_pieces=[""], + response_type="text", + error="empty", + ) + await self._calc_hash(request=empty_response) + self.memory.add_message_to_memory(request=empty_response) + return empty_response # Process all response messages (targets return list[Message]) # Only apply response converters to the last message (final response) @@ -210,7 +217,7 @@ async def send_prompt_batch_to_target_async( "conversation_id", ] - responses = await batch_task_async( + return await batch_task_async( prompt_target=target, batch_size=batch_size, items_to_batch=batch_items, @@ -221,9 +228,6 @@ async def send_prompt_batch_to_target_async( attack_identifier=attack_identifier, ) - # Filter out None responses (e.g., from empty responses) - return [response for response in responses if response is not None] - async def convert_values( self, converter_configurations: list[PromptConverterConfiguration], diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 66662886f..bc10d3323 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -117,21 +117,23 @@ async def test_send_prompt_async_multiple_converters(mock_memory_instance, seed_ @pytest.mark.asyncio -async def test_send_prompt_async_no_response_raises_empty_response(mock_memory_instance, seed_group): - prompt_target = AsyncMock() +async def test_send_prompt_async_no_response_adds_memory(mock_memory_instance, seed_group): + prompt_target = MagicMock() prompt_target.send_prompt_async = AsyncMock(return_value=None) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") - with pytest.raises(EmptyResponseException): - await normalizer.send_prompt_async(message=message, target=prompt_target) - - # Request should still be added to memory before the exception - assert mock_memory_instance.add_message_to_memory.call_count == 1 + response = await normalizer.send_prompt_async(message=message, target=prompt_target) + assert mock_memory_instance.add_message_to_memory.call_count == 2 request = mock_memory_instance.add_message_to_memory.call_args[1]["request"] assert_message_piece_hashes_set(request) + assert response.message_pieces[0].response_error == "empty" + assert response.message_pieces[0].original_value == "" + assert response.message_pieces[0].original_value_data_type == "text" + assert_message_piece_hashes_set(response) @pytest.mark.asyncio @@ -187,34 +189,29 @@ async def test_send_prompt_async_request_response_added_to_memory(mock_memory_in @pytest.mark.asyncio async def test_send_prompt_async_exception(mock_memory_instance, seed_group): - prompt_target = AsyncMock() + prompt_target = MagicMock() + prompt_target.send_prompt_async = AsyncMock(side_effect=ValueError("test_exception")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") seed_prompt_value = seed_group.prompts[0].value normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_prompt_value, role="user") - with patch("pyrit.models.construct_response_from_request") as mock_construct: - mock_construct.return_value = "test" + with pytest.raises(Exception, match="Error sending prompt with conversation ID"): + await normalizer.send_prompt_async(message=message, target=prompt_target) - try: - await normalizer.send_prompt_async(message=message, target=prompt_target) - except ValueError: - assert mock_memory_instance.add_message_to_memory.call_count == 2 + assert mock_memory_instance.add_message_to_memory.call_count == 2 - # Validate that first request is added to memory, then exception is added to memory - assert ( - seed_prompt_value - == mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"] - .message_pieces[0] - .original_value - ) - assert ( - mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"] - .message_pieces[0] - .original_value - == "test_exception" - ) + # Validate that first request is added to memory, then exception is added to memory + assert ( + seed_prompt_value + == mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"].message_pieces[0].original_value + ) + assert ( + "test_exception" + in mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"].message_pieces[0].original_value + ) @pytest.mark.asyncio @@ -386,6 +383,56 @@ async def test_prompt_normalizer_send_prompt_batch_async_throws( assert len(results) == 1 +@pytest.mark.asyncio +async def test_prompt_normalizer_send_prompt_batch_async_preserves_empty_response_alignment( + mock_memory_instance, +): + prompt_target = MagicMock() + prompt_target._max_requests_per_minute = None + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") + prompt_target.send_prompt_async = AsyncMock( + side_effect=[ + [MessagePiece(role="assistant", original_value="response 1", conversation_id="conv-1").to_message()], + None, + ] + ) + + normalizer = PromptNormalizer() + requests = [ + NormalizerRequest( + message=Message.from_prompt(prompt="prompt 1", role="user"), + conversation_id="conv-1", + ), + NormalizerRequest( + message=Message.from_prompt(prompt="prompt 2", role="user"), + conversation_id="conv-2", + ), + ] + + results = await normalizer.send_prompt_batch_to_target_async(requests=requests, target=prompt_target, batch_size=2) + + assert len(results) == 2 + assert results[0].message_pieces[0].original_value == "response 1" + assert results[1].message_pieces[0].response_error == "empty" + assert results[1].message_pieces[0].original_value == "" + assert results[1].message_pieces[0].conversation_id == "conv-2" + + +@pytest.mark.asyncio +async def test_send_prompt_async_none_in_list_response_returns_empty(mock_memory_instance, seed_group): + """Target returning [None] (list containing None) should produce an empty response.""" + prompt_target = MagicMock() + prompt_target.send_prompt_async = AsyncMock(return_value=[None]) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") + + normalizer = PromptNormalizer() + message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") + + response = await normalizer.send_prompt_async(message=message, target=prompt_target) + assert response.message_pieces[0].response_error == "empty" + assert response.message_pieces[0].original_value == "" + + @pytest.mark.asyncio async def test_build_message(mock_memory_instance, seed_group): # This test is obsolete since _build_message was removed and message preparation