From 2b120100506e0c71e63267379adb53bbb31207f4 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 10:19:12 -0700 Subject: [PATCH] Deprecate unused fields in message piece --- pyrit/memory/memory_interface.py | 6 ++ pyrit/models/message_piece.py | 29 ++++++++ .../test_interface_attack_results.py | 17 +++++ tests/unit/models/test_message_piece.py | 66 +++++++++++++++++++ 4 files changed, 118 insertions(+) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index a0abed147..9bf33031f 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -5,6 +5,7 @@ import atexit import logging import uuid +import warnings import weakref from collections.abc import MutableSequence, Sequence from contextlib import closing @@ -1531,6 +1532,11 @@ def get_attack_results( ) if targeted_harm_categories: + warnings.warn( + "The 'targeted_harm_categories' parameter is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) # Use database-specific JSON query method conditions.append( self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 91d01032b..2a58f9d0d 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -4,6 +4,7 @@ from __future__ import annotations import uuid +import warnings from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args from uuid import uuid4 @@ -15,6 +16,7 @@ from pyrit.models.score import Score Originator = Literal["attack", "converter", "undefined", "scorer"] +"""Deprecated: The Originator type alias will be removed in a future release.""" class MessagePiece: @@ -135,6 +137,12 @@ def __init__( ) # Handle scorer_identifier: normalize to ComponentIdentifier (handles dict with deprecation warning) + if scorer_identifier is not None: + warnings.warn( + "The 'scorer_identifier' parameter is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) self.scorer_identifier: Optional[ComponentIdentifier] = ( ComponentIdentifier.normalize(scorer_identifier) if scorer_identifier else None ) @@ -161,12 +169,33 @@ def __init__( raise ValueError(f"response_error {response_error} is not a valid response error.") self.response_error = response_error + + if originator != "undefined": + warnings.warn( + "The 'originator' parameter is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) self.originator = originator # Original prompt id defaults to id (assumes that this is the original prompt, not a duplicate) self.original_prompt_id = original_prompt_id or self.id + if scores is not None: + warnings.warn( + "The 'scores' parameter is deprecated and will be removed in a future release. " + "Scores are now hydrated by the memory layer.", + DeprecationWarning, + stacklevel=2, + ) self.scores = scores if scores else [] + + if targeted_harm_categories is not None: + warnings.warn( + "The 'targeted_harm_categories' parameter is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) self.targeted_harm_categories = targeted_harm_categories if targeted_harm_categories else [] async def set_sha256_values_async(self) -> None: diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 2e30ba368..6a3857c57 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -1414,3 +1414,20 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance ], ) assert len(results) == 0 + + +def test_get_attack_results_targeted_harm_categories_emits_deprecation_warning(sqlite_instance: MemoryInterface): + """Test that passing targeted_harm_categories emits a DeprecationWarning.""" + import warnings + + message_piece = create_message_piece("conv_1", 1, targeted_harm_categories=["violence"]) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) + + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + sqlite_instance.get_attack_results(targeted_harm_categories=["violence"]) + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs) diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index d13064a32..a536cdbb0 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -5,6 +5,7 @@ import tempfile import time import uuid +import warnings from collections.abc import MutableSequence from datetime import datetime, timedelta, timezone @@ -1043,3 +1044,68 @@ def test_role_setter_sets_simulated_assistant(self): assert piece.get_role_for_storage() == "simulated_assistant" assert piece.api_role == "assistant" assert piece.is_simulated is True + + +class TestMessagePieceDeprecationWarnings: + """Tests for deprecation warnings on parameters scheduled for removal.""" + + def test_scorer_identifier_emits_deprecation_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece( + role="user", + original_value="Hello", + scorer_identifier=ComponentIdentifier(class_name="S", class_module="m"), + ) + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert any("scorer_identifier" in str(m.message) for m in deprecation_msgs) + + def test_scorer_identifier_none_no_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello") + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert not any("scorer_identifier" in str(m.message) for m in deprecation_msgs) + + def test_originator_non_default_emits_deprecation_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello", originator="attack") + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert any("originator" in str(m.message) for m in deprecation_msgs) + + def test_originator_default_no_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello") + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert not any("originator" in str(m.message) for m in deprecation_msgs) + + def test_scores_emits_deprecation_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello", scores=[]) + # scores=[] is falsy but not None, however the check is `scores is not None` + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert any("scores" in str(m.message) for m in deprecation_msgs) + + def test_scores_none_no_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello") + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert not any("scores" in str(m.message) for m in deprecation_msgs) + + def test_targeted_harm_categories_emits_deprecation_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello", targeted_harm_categories=["violence"]) + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs) + + def test_targeted_harm_categories_none_no_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MessagePiece(role="user", original_value="Hello") + deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert not any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs)