From f1186d9873b657a7cb095f0a6e43987aab47bbb2 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Mon, 20 Apr 2026 20:30:36 -0400 Subject: [PATCH 1/8] remove promptchattarget ref in converters & add native target requirements & migrate towards using target config from target capabilities and misc checks --- .../attack/component/conversation_manager.py | 8 ++- .../attack/multi_turn/chunked_request.py | 16 ++--- pyrit/executor/attack/multi_turn/crescendo.py | 20 +++---- .../attack/multi_turn/multi_prompt_sending.py | 17 +++--- .../attack/multi_turn/tree_of_attacks.py | 11 +++- pyrit/prompt_converter/denylist_converter.py | 6 +- .../llm_generic_text_converter.py | 11 ++-- .../malicious_question_generator_converter.py | 6 +- .../prompt_converter/math_prompt_converter.py | 6 +- pyrit/prompt_converter/noise_converter.py | 6 +- .../prompt_converter/persuasion_converter.py | 7 ++- .../random_translation_converter.py | 6 +- .../scientific_translation_converter.py | 6 +- pyrit/prompt_converter/tense_converter.py | 6 +- pyrit/prompt_converter/tone_converter.py | 6 +- .../toxic_sentence_generator_converter.py | 6 +- .../prompt_converter/translation_converter.py | 7 ++- pyrit/prompt_converter/variation_converter.py | 7 ++- pyrit/prompt_target/__init__.py | 6 +- .../common/prompt_chat_target.py | 26 +++------ pyrit/prompt_target/common/prompt_target.py | 52 ++++++++++++++++- .../common/target_requirements.py | 58 ++++++++++++++++--- pyrit/scenario/scenarios/airt/psychosocial.py | 6 +- pyrit/score/float_scale/float_scale_scorer.py | 4 +- .../score/float_scale/insecure_code_scorer.py | 5 +- .../self_ask_general_float_scale_scorer.py | 4 +- .../float_scale/self_ask_likert_scorer.py | 5 +- .../float_scale/self_ask_scale_scorer.py | 5 +- pyrit/score/scorer.py | 2 +- pyrit/score/true_false/gandalf_scorer.py | 7 ++- .../true_false/self_ask_category_scorer.py | 5 +- .../self_ask_general_true_false_scorer.py | 4 +- .../self_ask_question_answer_scorer.py | 4 +- 
.../true_false/self_ask_refusal_scorer.py | 5 +- .../true_false/self_ask_true_false_scorer.py | 5 +- .../component/test_conversation_manager.py | 6 ++ .../test_supports_multi_turn_attacks.py | 48 ++++++--------- .../target/test_target_requirements.py | 47 +++++++++++++++ 38 files changed, 308 insertions(+), 154 deletions(-) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 7a27cb5666..8556125318 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -21,6 +21,7 @@ from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: from pyrit.executor.attack.core import AttackContext @@ -321,8 +322,11 @@ async def initialize_context_async( logger.debug(f"No prepended conversation for context initialization: {conversation_id}") return state - # Handle target type compatibility - is_chat_target = isinstance(target, PromptChatTarget) + # Targets that don't natively support multi-turn history cannot consume a + # prepended multi-message conversation as-is — route them to the + # single-string fallback path. Type identity (PromptChatTarget) is a + # legacy signal for this; capability-based routing is the durable form. 
+ is_chat_target = target.configuration.includes(capability=CapabilityName.MULTI_TURN) if not is_chat_target: return await self._handle_non_chat_target_async( context=context, diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 1a70c89195..d22166f644 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -29,6 +29,8 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.score import TrueFalseScorer @@ -141,6 +143,12 @@ def __init__( params_type=ChunkedRequestAttackParameters, ) + # Chunked request issues multiple distinct turns; history-squash + # adaptation would collapse them into a single prompt. + TargetRequirements( + required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), + ).validate(configuration=objective_target.configuration) + # Store chunk configuration self._chunk_size = chunk_size self._total_length = total_length @@ -226,15 +234,7 @@ async def _setup_async(self, *, context: ChunkedRequestAttackContext) -> None: Args: context (ChunkedRequestAttackContext): The attack context containing attack parameters. - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "ChunkedRequestAttack requires a multi-turn target. " - "The objective target does not support multi-turn conversations." 
- ) # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 4a180d5df3..7ded67b3ca 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -44,6 +44,8 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -148,11 +150,18 @@ def __init__( application by role, message normalization, and non-chat target behavior. Raises: - ValueError: If objective_target is not a PromptChatTarget. + ValueError: If objective_target is not a PromptChatTarget, or does not + natively support multi-turn conversations. """ # Initialize base class super().__init__(objective_target=objective_target, logger=logger, context_type=CrescendoAttackContext) + # Crescendo fundamentally relies on multi-turn conversation history to + # gradually escalate prompts; history-squash adaptation would defeat it. + TargetRequirements( + required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), + ).validate(configuration=objective_target.configuration) + self._memory = CentralMemory.get_memory_instance() # Initialize converter configuration @@ -257,16 +266,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: Args: context (CrescendoAttackContext): Attack context with configuration - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "CrescendoAttack requires a multi-turn target. Crescendo fundamentally relies on " - "multi-turn conversation history to gradually escalate prompts. 
" - "Use RedTeamingAttack or TreeOfAttacksWithPruning instead." - ) # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index a9d4b75adc..40158b9e4d 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -29,6 +29,8 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import Scorer if TYPE_CHECKING: @@ -152,6 +154,12 @@ def __init__( params_type=MultiPromptSendingAttackParameters, ) + # Sending a sequence of prompts requires a real multi-turn target; + # history-squash adaptation would collapse them into one message. + TargetRequirements( + required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), + ).validate(configuration=objective_target.configuration) + # Initialize the converter configuration attack_converter_config = attack_converter_config or AttackConverterConfig() self._request_converters = attack_converter_config.request_converters @@ -204,16 +212,7 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: Args: context (MultiTurnAttackContext): The attack context containing attack parameters. - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "MultiPromptSendingAttack requires a multi-turn target. " - "The objective target does not support multi-turn conversations." 
- ) - # Ensure the context has a session (like red_teaming.py does) context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index e92bd1cf67..653a90c773 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -51,6 +51,8 @@ ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -1322,8 +1324,13 @@ def __init__( # Initialize adversarial configuration self._adversarial_chat = attack_adversarial_config.target - if not isinstance(self._adversarial_chat, PromptChatTarget): - raise ValueError("The adversarial target must be a PromptChatTarget for TAP attack.") + # TAP sets a system prompt on the adversarial target and drives a + # multi-turn dialogue through it; both capabilities must be native. 
+ TargetRequirements( + required_native_capabilities=frozenset( + {CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT} + ), + ).validate(configuration=self._adversarial_chat.configuration) # Load system prompts self._adversarial_chat_system_prompt_path = ( diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index a9672e3718..9cc21ba653 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class DenylistConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, denylist: list[str] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with a target, an optional system prompt template, and a denylist. Args: - converter_target (PromptChatTarget): The target for the prompt conversion. + converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. If not provided, a default template will be used. 
diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index f56990247f..b8b22f127b 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -15,7 +15,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ class LLMGenericTextConverter(PromptConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, user_prompt_template_with_objective: Optional[SeedPrompt] = None, **kwargs: Any, @@ -41,8 +41,10 @@ def __init__( Initialize the converter with a target and optional prompt templates. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. - Can be omitted if a default has been configured via PyRIT initialization. + converter_target (PromptTarget): The endpoint that converts the prompt. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, possibly via + normalization-pipeline adaptation). Can be omitted if a default has been configured + via PyRIT initialization. system_prompt_template (SeedPrompt, Optional): The prompt template to set as the system prompt. user_prompt_template_with_objective (SeedPrompt, Optional): The prompt template to set as the user prompt. expects @@ -51,6 +53,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) self._converter_target = converter_target self._system_prompt_template = system_prompt_template self._prompt_kwargs = kwargs diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py index 41a7848458..fb9e225261 100644 --- a/pyrit/prompt_converter/malicious_question_generator_converter.py +++ b/pyrit/prompt_converter/malicious_question_generator_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -26,14 +26,14 @@ class MaliciousQuestionGeneratorConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. 
""" diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py index fd6491bbc1..8d809a6661 100644 --- a/pyrit/prompt_converter/math_prompt_converter.py +++ b/pyrit/prompt_converter/math_prompt_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -26,14 +26,14 @@ class MathPromptConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. 
""" diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index 0d7bdf302f..a89e2d85c7 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -11,7 +11,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class NoiseConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] noise: Optional[str] = None, number_errors: int = 5, prompt_template: Optional[SeedPrompt] = None, @@ -36,7 +36,7 @@ def __init__( Initialize the converter with the specified parameters. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. noise (str): The noise to inject. Grammar error, delete random letter, insert random space, etc. number_errors (int): The number of errors to inject. 
diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index 11b6bd66e6..5ff7b51c9d 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -21,7 +21,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -52,14 +52,14 @@ class PersuasionConverter(PromptConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] persuasion_technique: str, ): """ Initialize the converter with the specified target and prompt template. Args: - converter_target (PromptChatTarget): The chat target used to perform rewriting on user prompts. + converter_target (PromptTarget): The chat target used to perform rewriting on user prompts. Can be omitted if a default has been configured via PyRIT initialization. persuasion_technique (str): Persuasion technique to be used by the converter, determines the system prompt to be used to generate new prompts. Must be one of "authority_endorsement", "evidence_based", @@ -69,6 +69,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the persuasion technique is not supported or does not exist. 
""" + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) self.converter_target = converter_target try: diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 74953c2603..4711e0e0d5 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -13,7 +13,7 @@ from pyrit.prompt_converter.prompt_converter import ConverterResult from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy from pyrit.prompt_converter.word_level_converter import WordLevelConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, languages: Optional[list[str]] = None, word_selection_strategy: Optional[WordSelectionStrategy] = None, @@ -44,7 +44,7 @@ def __init__( Initialize the converter with a target, an optional system prompt template, and language options. Args: - converter_target (PromptChatTarget): The target for the prompt conversion. + converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. If not provided, a default template will be used. 
diff --git a/pyrit/prompt_converter/scientific_translation_converter.py b/pyrit/prompt_converter/scientific_translation_converter.py index 2a6c965996..bdc7987041 100644 --- a/pyrit/prompt_converter/scientific_translation_converter.py +++ b/pyrit/prompt_converter/scientific_translation_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ class ScientificTranslationConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] mode: str = "combined", prompt_template: Optional[SeedPrompt] = None, ) -> None: @@ -53,7 +53,7 @@ def __init__( Initialize the scientific translation converter. Args: - converter_target (PromptChatTarget): The LLM target to perform the conversion. + converter_target (PromptTarget): The LLM target to perform the conversion. mode (str): The translation mode to use. 
Built-in options are: - ``academic``: Use academic/homework style framing diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index 237a2934d5..66a64d0158 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class TenseConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] tense: str, prompt_template: Optional[SeedPrompt] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with the target chat support, tense, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. tense (str): The tense the converter should convert the prompt to. E.g. past, present, future. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. 
diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index a7b8e5a9f1..69ee355aaf 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class ToneConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] tone: str, prompt_template: Optional[SeedPrompt] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with the target chat support, tone, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. tone (str): The tone for the conversation. E.g. upset, sarcastic, indifferent, etc. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. 
diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py index d3390c6af7..67922b4c03 100644 --- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py +++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py @@ -14,7 +14,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -34,14 +34,14 @@ class ToxicSentenceGeneratorConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. If not provided, defaults to the ``toxic_sentence_generator.yaml``. 
diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index 911f72ab57..62ae0ad4c8 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -24,7 +24,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ class TranslationConverter(PromptConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] language: str, prompt_template: Optional[SeedPrompt] = None, max_retries: int = 3, @@ -51,7 +51,7 @@ def __init__( Initialize the converter with the target chat support, language, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. language (str): The language for the conversion. E.g. Spanish, French, leetspeak, etc. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. @@ -62,6 +62,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the language is not provided. 
""" + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) self.converter_target = converter_target # Retry strategy for the conversion diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 328e463072..8af7438202 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -23,7 +23,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -40,14 +40,14 @@ class VariationConverter(PromptConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with the specified target and prompt template. Args: - converter_target (PromptChatTarget): The target to which the prompt will be sent for conversion. + converter_target (PromptTarget): The target to which the prompt will be sent for conversion. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt, optional): The template used for generating the system prompt. If not provided, a default template will be used. @@ -55,6 +55,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) self.converter_target = converter_target # set to default strategy if not provided diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index c71dca4089..db24087d22 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -20,7 +20,10 @@ UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.target_requirements import TargetRequirements +from pyrit.prompt_target.common.target_requirements import ( + CHAT_CONSUMER_REQUIREMENTS, + TargetRequirements, +) from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget from pyrit.prompt_target.http_target.http_target import HTTPTarget @@ -51,6 +54,7 @@ "AzureMLChatTarget", "CapabilityName", "CapabilityHandlingPolicy", + "CHAT_CONSUMER_REQUIREMENTS", "CopilotType", "ConversationNormalizationPipeline", "GandalfLevel", diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index ce1f254678..a84488c25f 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -73,26 +73,16 @@ def set_system_prompt( labels: Optional[dict[str, str]] = None, ) -> None: """ - Set the system prompt for the prompt target. May be overridden by subclasses. + Deprecated shim. Use :meth:`PromptTarget.set_system_prompt` on the base class. - Raises: - RuntimeError: If the conversation already exists. + Retained on ``PromptChatTarget`` so subclasses that override this method + continue to work. Delegates to the base-class implementation. 
""" - messages = self._memory.get_conversation(conversation_id=conversation_id) - - if messages: - raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") - - self._memory.add_message_to_memory( - request=MessagePiece( - role="system", - conversation_id=conversation_id, - original_value=system_prompt, - converted_value=system_prompt, - prompt_target_identifier=self.get_identifier(), - attack_identifier=attack_identifier, - labels=labels, - ).to_message() + super().set_system_prompt( + system_prompt=system_prompt, + conversation_id=conversation_id, + attack_identifier=attack_identifier, + labels=labels, ) def is_response_format_json(self, message_piece: MessagePiece) -> bool: diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 53d7d2085a..de8ef64b35 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -8,7 +8,7 @@ from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import Message +from pyrit.models import Message, MessagePiece from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration, resolve_configuration_compat @@ -270,6 +270,56 @@ def set_model_name(self, *, model_name: str) -> None: """ self._model_name = model_name + def set_system_prompt( + self, + *, + system_prompt: str, + conversation_id: str, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, + ) -> None: + """ + Inject a system prompt into memory for the given conversation. + + Writes a ``system``-role message so the target's normalization pipeline + (or the target itself, when it natively supports system prompts) will + pick it up on the next ``send_prompt_async`` call. 
+ + If the target does not natively support system prompts, whether this + call is ultimately honored depends on the target's + :class:`CapabilityHandlingPolicy`: + + * ``ADAPT`` — the normalization pipeline (e.g. system squash) will + fold the system message into user content on the wire. + * ``RAISE`` — the first send after the system prompt is set will + raise, because the pipeline cannot adapt the missing capability. + + Args: + system_prompt (str): The system prompt text to set. + conversation_id (str): The conversation id to attach the prompt to. + attack_identifier (ComponentIdentifier | None): Optional attack identifier. + labels (dict[str, str] | None): Optional labels. + + Raises: + RuntimeError: If the conversation already has messages. + """ + messages = self._memory.get_conversation(conversation_id=conversation_id) + + if messages: + raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") + + self._memory.add_message_to_memory( + request=MessagePiece( + role="system", + conversation_id=conversation_id, + original_value=system_prompt, + converted_value=system_prompt, + prompt_target_identifier=self.get_identifier(), + attack_identifier=attack_identifier, + labels=labels, + ).to_message() + ) + def dispose_db_engine(self) -> None: """ Dispose database engine to release database connections and resources. diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index 95182b47b5..cde3640690 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -20,26 +20,44 @@ class TargetRequirements: Consumers define their requirements once and validate them against a ``TargetConfiguration`` at construction time. This replaces ad-hoc ``isinstance`` checks and scattered capability branching. 
+ + Two levels of requirement are supported: + + * ``required_capabilities`` — the target must *handle* the capability, + either natively or via an ``ADAPT`` policy (normalization pipeline). + Use this when the consumer only cares that the behavior is available + on the wire, regardless of how. + * ``required_native_capabilities`` — the target must support the + capability natively; adaptation via the normalization pipeline is + not acceptable. Use this when adaptation would defeat the consumer's + purpose (e.g. a multi-turn attack cannot run against a target whose + history is squashed into a single prompt). """ - # The set of capabilities the consumer requires. + # Capabilities the consumer requires, native or adapted. required_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) + # Capabilities the consumer requires to be natively supported. Adaptation + # via the normalization pipeline is not acceptable for these capabilities. + required_native_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) + def validate(self, *, configuration: TargetConfiguration) -> None: """ Validate that the target configuration can satisfy all requirements. - Iterates over every required capability and delegates to - ``TargetConfiguration.ensure_can_handle``, which checks native support - first and then consults the handling policy. All violations are - collected and reported in a single ``ValueError``. + For ``required_capabilities`` this delegates to + ``TargetConfiguration.ensure_can_handle``, which accepts either + native support or an ``ADAPT`` policy. For + ``required_native_capabilities`` this checks + ``TargetConfiguration.includes`` directly — adaptation is not + acceptable. All violations are collected and reported in a single + ``ValueError``. Args: configuration (TargetConfiguration): The target configuration to validate against. 
Raises: - ValueError: If any required capability is missing and the policy - does not allow adaptation. + ValueError: If any required capability cannot be satisfied. """ errors: list[str] = [] for capability in sorted(self.required_capabilities, key=lambda c: c.value): @@ -47,8 +65,34 @@ def validate(self, *, configuration: TargetConfiguration) -> None: configuration.ensure_can_handle(capability=capability) except ValueError as exc: errors.append(str(exc)) + for capability in sorted(self.required_native_capabilities, key=lambda c: c.value): + if not configuration.includes(capability=capability): + errors.append( + f"Target does not natively support '{capability.value}' " + "and adaptation is not acceptable for this consumer." + ) if errors: raise ValueError( f"Target does not satisfy {len(errors)} required capability(ies):\n" + "\n".join(f" - {e}" for e in errors) ) + + +def _build_chat_consumer_requirements() -> TargetRequirements: + # Imported lazily to avoid a hard import cycle with target_capabilities at + # module load time (target_requirements only type-checks CapabilityName). + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + return TargetRequirements( + required_capabilities=frozenset( + {CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN} + ), + ) + + +# Requirements declared by code paths that historically demanded a +# ``PromptChatTarget`` (converters and scorers that call ``set_system_prompt`` +# and then send a short conversation). Adaptation via the normalization +# pipeline is acceptable here — the consumer only needs the *behavior*, not +# native support. 
+CHAT_CONSUMER_REQUIREMENTS: TargetRequirements = _build_chat_consumer_requirements() diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 9d201baf9d..1dd9f389e1 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -28,6 +28,7 @@ PromptConverterConfiguration, ) from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -421,9 +422,10 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: if self._objective_target is None: raise ValueError("objective_target must be set before creating attacks") - if not isinstance(self._objective_target, PromptChatTarget): + if not self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): raise TypeError( - f"PsychosocialHarmsScenario requires a PromptChatTarget, got {type(self._objective_target).__name__}" + f"PsychosocialHarmsScenario requires a target that natively supports " + f"multi-turn conversations, got {type(self._objective_target).__name__}." 
) resolved = self._resolve_seed_groups() self._seed_groups = resolved.seed_groups diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index af39cf5bec..a22d29921c 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -7,7 +7,7 @@ from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType, Score, UnvalidatedScore -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -66,7 +66,7 @@ def get_scorer_metrics(self) -> Optional["HarmScorerMetrics"]: async def _score_value_with_llm( self, *, - prompt_target: PromptChatTarget, + prompt_target: PromptTarget, system_prompt: str, message_value: str, message_data_type: PromptDataType, diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 45c64dab00..85fdfe86e6 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -9,7 +9,7 @@ from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -25,7 +25,7 @@ class InsecureCodeScorer(FloatScaleScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_path: Optional[Union[str, Path]] = None, validator: Optional[ScorerPromptValidator] 
= None, ): @@ -40,6 +40,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index ae9e0acc4b..b2796106ff 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): @@ -28,7 +28,7 @@ class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_format_string: str, prompt_format_string: Optional[str] = None, category: Optional[str] = None, diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index c6762089b6..60a87ade05 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -12,7 +12,7 @@ from pyrit.common.path import HARM_DEFINITION_PATH, SCORER_LIKERT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -177,7 +177,7 @@ class SelfAskLikertScorer(FloatScaleScorer): def __init__( self, *, - chat_target: 
PromptChatTarget, + chat_target: PromptTarget, likert_scale: Optional[LikertScalePaths] = None, custom_likert_path: Optional[Path] = None, custom_system_prompt_path: Optional[Path] = None, @@ -203,6 +203,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target self._likert_scale = likert_scale diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 4bf0dc2dee..d986293354 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -11,7 +11,7 @@ from pyrit.common.path import SCORER_SCALES_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -43,7 +43,7 @@ class SystemPaths(enum.Enum): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, scale_arguments_path: Optional[Union[Path, str]] = None, system_prompt_path: Optional[Union[Path, str]] = None, validator: Optional[ScorerPromptValidator] = None, @@ -61,6 +61,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index b18a1802a9..b37ef8dd94 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -494,7 +494,7 @@ def scale_value_float(self, value: float, min_value: float, max_value: float) -> async def _score_value_with_llm( self, *, - 
prompt_target: PromptChatTarget, + prompt_target: PromptTarget, system_prompt: str, message_value: str, message_data_type: PromptDataType, diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 2aab7c264e..b0a311bc86 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -11,7 +11,7 @@ from pyrit.exceptions import PyritException, pyrit_target_retry from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, MessagePiece, Score -from pyrit.prompt_target import GandalfLevel, PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, GandalfLevel, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -35,7 +35,7 @@ def __init__( self, *, level: GandalfLevel, - chat_target: PromptChatTarget, + chat_target: PromptTarget, validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: @@ -44,13 +44,14 @@ def __init__( Args: level (GandalfLevel): The Gandalf challenge level to score against. - chat_target (PromptChatTarget): The chat target used for password extraction. + chat_target (PromptTarget): The chat target used for password extraction. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to text data type validator. score_aggregator (TrueFalseAggregatorFunc): Aggregator for combining scores. Defaults to TrueFalseScoreAggregator.OR. 
""" super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target self._defender = level.value self._endpoint = "https://gandalf-api.lakera.ai/api/guess-password" diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 7102ba3af6..d9444e948d 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -11,7 +11,7 @@ from pyrit.common.path import SCORER_CONTENT_CLASSIFIERS_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -41,7 +41,7 @@ class SelfAskCategoryScorer(TrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, content_classifier_path: Union[str, Path], score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, validator: Optional[ScorerPromptValidator] = None, @@ -58,6 +58,7 @@ def __init__( """ super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target content_classifier_path = verify_and_resolve_path(content_classifier_path) diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 44bb362748..5f7f02f71d 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ 
b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): @@ -32,7 +32,7 @@ class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_format_string: str, prompt_format_string: Optional[str] = None, category: Optional[str] = None, diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index bf1c017dde..69b453be24 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -18,7 +18,7 @@ import pathlib from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): @@ -37,7 +37,7 @@ class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, true_false_question_path: Optional[pathlib.Path] = None, validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index cf9b30f1d8..67e35c17c5 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -8,7 +8,7 @@ from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from 
pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -68,7 +68,7 @@ class SelfAskRefusalScorer(TrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, refusal_system_prompt_path: Union[RefusalScorerPaths, Path, str] = RefusalScorerPaths.OBJECTIVE_STRICT, prompt_format_string: Optional[str] = None, validator: Optional[ScorerPromptValidator] = None, @@ -102,6 +102,7 @@ def __init__( super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target # Resolve the system prompt path diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index da1054274d..56698267f9 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -12,7 +12,7 @@ from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -97,7 +97,7 @@ class SelfAskTrueFalseScorer(TrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, true_false_question_path: Optional[Union[str, Path]] = None, true_false_question: Optional[TrueFalseQuestion] = None, true_false_system_prompt_path: Optional[Union[str, Path]] = None, @@ -122,6 +122,7 @@ 
def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) self._prompt_target = chat_target if true_false_question_path and true_false_question: diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index c86e741e9c..a87948c330 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -94,7 +94,13 @@ def mock_chat_target() -> MagicMock: @pytest.fixture def mock_prompt_target() -> MagicMock: """Create a mock prompt target (non-chat) for testing.""" + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + target = MagicMock(spec=PromptTarget) + target.configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False), + ) target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index 3baeaf463c..11d019ed76 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -684,8 +684,14 @@ class TestValueErrorGuards: """Test that incompatible attacks raise ValueError for single-turn targets.""" def _make_single_turn_target(self): + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + target = MagicMock() target.capabilities.supports_multi_turn = False + target.configuration = 
TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True), + ) target.get_identifier.return_value = MagicMock() return target @@ -706,52 +712,34 @@ def _make_scoring_config(self): @pytest.mark.asyncio async def test_crescendo_raises_for_single_turn_target(self): - from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack, CrescendoAttackContext + from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack target = self._make_single_turn_target() - attack = CrescendoAttack( - objective_target=target, - attack_adversarial_config=self._make_adversarial_config(), - attack_scoring_config=self._make_scoring_config(), - ) - - context = CrescendoAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="CrescendoAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + CrescendoAttack( + objective_target=target, + attack_adversarial_config=self._make_adversarial_config(), + attack_scoring_config=self._make_scoring_config(), + ) @pytest.mark.asyncio async def test_multi_prompt_sending_raises_for_single_turn_target(self): from pyrit.executor.attack.multi_turn.multi_prompt_sending import MultiPromptSendingAttack target = self._make_single_turn_target() - attack = MultiPromptSendingAttack(objective_target=target) - - context = MultiTurnAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="MultiPromptSendingAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + MultiPromptSendingAttack(objective_target=target) @pytest.mark.asyncio async def test_chunked_request_raises_for_single_turn_target(self): - from pyrit.executor.attack.multi_turn.chunked_request import ( - ChunkedRequestAttack, - ChunkedRequestAttackContext, - ) + from 
pyrit.executor.attack.multi_turn.chunked_request import ChunkedRequestAttack target = self._make_single_turn_target() - attack = ChunkedRequestAttack(objective_target=target) - - context = ChunkedRequestAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="ChunkedRequestAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + ChunkedRequestAttack(objective_target=target) @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/prompt_target/target/test_target_requirements.py b/tests/unit/prompt_target/target/test_target_requirements.py index 002ccf086c..a53550d3bf 100644 --- a/tests/unit/prompt_target/target/test_target_requirements.py +++ b/tests/unit/prompt_target/target/test_target_requirements.py @@ -129,3 +129,50 @@ def test_validate_mixed_adapt_and_raise(adapt_all_policy): ) with pytest.raises(ValueError, match="supports_json_output"): reqs.validate(configuration=config) + + +# --------------------------------------------------------------------------- +# required_native_capabilities +# --------------------------------------------------------------------------- + + +def test_validate_native_passes_when_supported_natively(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + reqs = TargetRequirements(required_native_capabilities=frozenset({CapabilityName.MULTI_TURN})) + reqs.validate(configuration=config) + + +def test_validate_native_raises_when_only_adapted(adapt_all_policy): + # multi_turn is missing but ADAPT — acceptable for required_capabilities + # but not for required_native_capabilities. 
+ caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + reqs = TargetRequirements(required_native_capabilities=frozenset({CapabilityName.MULTI_TURN})) + with pytest.raises(ValueError, match="natively support 'supports_multi_turn'"): + reqs.validate(configuration=config) + + +def test_validate_native_and_adapted_mixed(adapt_all_policy): + # system_prompt adapted (OK for required_capabilities); multi_turn required + # natively (FAIL — only adapted). + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + reqs = TargetRequirements( + required_capabilities=frozenset({CapabilityName.SYSTEM_PROMPT}), + required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), + ) + with pytest.raises(ValueError, match="natively support 'supports_multi_turn'"): + reqs.validate(configuration=config) + + +def test_validate_native_collects_multiple_violations(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps) + reqs = TargetRequirements( + required_native_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) + ) + with pytest.raises(ValueError, match="2 required capability") as exc_info: + reqs.validate(configuration=config) + assert "supports_multi_turn" in str(exc_info.value) + assert "supports_system_prompt" in str(exc_info.value) From 229d3b93c41b117243a3177645fb826c206feee1 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Tue, 21 Apr 2026 17:37:56 -0400 Subject: [PATCH 2/8] whitespace --- pyrit/executor/attack/multi_turn/chunked_request.py | 1 - pyrit/executor/attack/multi_turn/crescendo.py | 1 - pyrit/executor/attack/multi_turn/tree_of_attacks.py | 4 +--- pyrit/prompt_target/common/target_requirements.py | 4 +--- pyrit/score/scorer.py | 2 +- 5 files 
changed, 3 insertions(+), 9 deletions(-) diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index d22166f644..459331366d 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -235,7 +235,6 @@ async def _setup_async(self, *, context: ChunkedRequestAttackContext) -> None: Args: context (ChunkedRequestAttackContext): The attack context containing attack parameters. """ - # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 7ded67b3ca..143686c2cd 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -267,7 +267,6 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: Args: context (CrescendoAttackContext): Attack context with configuration """ - # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 653a90c773..a4c97f5e3b 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -1327,9 +1327,7 @@ def __init__( # TAP sets a system prompt on the adversarial target and drives a # multi-turn dialogue through it; both capabilities must be native. 
TargetRequirements( - required_native_capabilities=frozenset( - {CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT} - ), + required_native_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}), ).validate(configuration=self._adversarial_chat.configuration) # Load system prompts diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index cde3640690..f018e395b1 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -84,9 +84,7 @@ def _build_chat_consumer_requirements() -> TargetRequirements: from pyrit.prompt_target.common.target_capabilities import CapabilityName return TargetRequirements( - required_capabilities=frozenset( - {CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN} - ), + required_capabilities=frozenset({CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN}), ) diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index b37ef8dd94..fed289ae4c 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pyrit.prompt_target import PromptChatTarget, PromptTarget + from pyrit.prompt_target import PromptTarget from pyrit.score.scorer_evaluation.metrics_type import RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_evaluator import ( ScorerEvalDatasetFiles, From bf7fcadf0a180f8e910bd2ae2d94f845b6f67769 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 22 Apr 2026 17:23:19 -0400 Subject: [PATCH 3/8] remove required native capabilities and simplify targetrequirements --- pyrit/executor/attack/core/attack_strategy.py | 9 +- .../attack/multi_turn/chunked_request.py | 9 +- pyrit/executor/attack/multi_turn/crescendo.py | 9 +- .../attack/multi_turn/multi_prompt_sending.py | 9 +- .../attack/multi_turn/tree_of_attacks.py | 15 +- .../llm_generic_text_converter.py | 3 +- 
.../prompt_converter/persuasion_converter.py | 3 +- pyrit/prompt_converter/prompt_converter.py | 25 ++- .../prompt_converter/translation_converter.py | 3 +- pyrit/prompt_converter/variation_converter.py | 3 +- .../common/prompt_chat_target.py | 2 +- .../common/target_requirements.py | 92 ++------- .../score/float_scale/insecure_code_scorer.py | 3 +- .../float_scale/self_ask_likert_scorer.py | 3 +- .../float_scale/self_ask_scale_scorer.py | 3 +- pyrit/score/scorer.py | 21 +++ pyrit/score/true_false/gandalf_scorer.py | 3 +- .../true_false/self_ask_category_scorer.py | 3 +- .../true_false/self_ask_refusal_scorer.py | 3 +- .../true_false/self_ask_true_false_scorer.py | 3 +- .../target/test_target_requirements.py | 178 ++---------------- 21 files changed, 135 insertions(+), 267 deletions(-) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index dac86d10aa..87ef16a813 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -8,7 +8,7 @@ import time from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -27,6 +27,7 @@ ConversationReference, Message, ) +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig @@ -233,6 +234,10 @@ class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Id Defines the interface for executing attacks and handling results. """ + #: Capability requirements placed on ``objective_target``. Subclasses override + #: to declare what the attack needs. Validated in ``__init__``. 
+ target_requirements: ClassVar[TargetRequirements] = TargetRequirements() + def __init__( self, *, @@ -259,6 +264,8 @@ def __init__( ), logger=logger, ) + for capability in type(self).target_requirements.required: + objective_target.configuration.ensure_can_handle(capability=capability) self._objective_target = objective_target self._params_type = params_type # Guard so subclasses that set converters before calling super() aren't clobbered diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 459331366d..57b6c8816d 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -30,7 +30,6 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName -from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.score import TrueFalseScorer @@ -145,9 +144,11 @@ def __init__( # Chunked request issues multiple distinct turns; history-squash # adaptation would collapse them into a single prompt. - TargetRequirements( - required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), - ).validate(configuration=objective_target.configuration) + if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): + raise ValueError( + "ChunkedRequestAttack requires a target that natively supports " + f"'{CapabilityName.MULTI_TURN.value}'." 
+ ) # Store chunk configuration self._chunk_size = chunk_size diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 143686c2cd..0419e4c2f6 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -45,7 +45,6 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptChatTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName -from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -158,9 +157,11 @@ def __init__( # Crescendo fundamentally relies on multi-turn conversation history to # gradually escalate prompts; history-squash adaptation would defeat it. - TargetRequirements( - required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), - ).validate(configuration=objective_target.configuration) + if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): + raise ValueError( + "CrescendoAttack requires a target that natively supports " + f"'{CapabilityName.MULTI_TURN.value}'." 
+ ) self._memory = CentralMemory.get_memory_instance() diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 40158b9e4d..c330492a9f 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -30,7 +30,6 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName -from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import Scorer if TYPE_CHECKING: @@ -156,9 +155,11 @@ def __init__( # Sending a sequence of prompts requires a real multi-turn target; # history-squash adaptation would collapse them into one message. - TargetRequirements( - required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), - ).validate(configuration=objective_target.configuration) + if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): + raise ValueError( + "MultiPromptSendingAttack requires a target that natively supports " + f"'{CapabilityName.MULTI_TURN.value}'." 
+ ) # Initialize the converter configuration attack_converter_config = attack_converter_config or AttackConverterConfig() diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 4163a94314..7772ffb372 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -52,7 +52,6 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName -from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -1326,9 +1325,17 @@ def __init__( self._adversarial_chat = attack_adversarial_config.target # TAP sets a system prompt on the adversarial target and drives a # multi-turn dialogue through it; both capabilities must be native. - TargetRequirements( - required_native_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}), - ).validate(configuration=self._adversarial_chat.configuration) + adversarial_config = self._adversarial_chat.configuration + missing_native = [ + capability.value + for capability in (CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT) + if not adversarial_config.includes(capability=capability) + ] + if missing_native: + raise ValueError( + "TreeOfAttacksWithPruningAttack requires an adversarial target that natively supports: " + + ", ".join(missing_native) + ) # Load system prompts self._adversarial_chat_system_prompt_path = ( diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index b8b22f127b..31614f135b 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -27,6 +27,7 @@ class LLMGenericTextConverter(PromptConverter): 
SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + target_requirements = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -53,7 +54,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. """ - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) + self._validate_target_requirements(target=converter_target) self._converter_target = converter_target self._system_prompt_template = system_prompt_template self._prompt_kwargs = kwargs diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index 5ff7b51c9d..85c441fd3e 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -47,6 +47,7 @@ class PersuasionConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + target_requirements = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -69,7 +70,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the persuasion technique is not supported or does not exist. 
""" - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) + self._validate_target_requirements(target=converter_target) self.converter_target = converter_target try: diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 141076e701..dc82dbbd0e 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,11 +6,15 @@ import inspect import re from dataclasses import dataclass -from typing import Any, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, get_args from pyrit import prompt_converter from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.models import PromptDataType +from pyrit.prompt_target.common.target_requirements import TargetRequirements + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget @dataclass @@ -48,6 +52,11 @@ class PromptConverter(Identifiable): #: Tuple of output modalities supported by this converter. Subclasses must override this. SUPPORTED_OUTPUT_TYPES: tuple[PromptDataType, ...] = () + #: Capability requirements placed on the converter's target (if any). + #: Subclasses that use a target should override this and call + #: ``self._validate_target_requirements(target=converter_target)`` in ``__init__``. + target_requirements: ClassVar[TargetRequirements] = TargetRequirements() + _identifier: Optional[ComponentIdentifier] = None def __init_subclass__(cls, **kwargs: object) -> None: @@ -81,6 +90,20 @@ def __init__(self) -> None: """ super().__init__() + def _validate_target_requirements(self, *, target: "PromptTarget") -> None: + """ + Validate ``target`` against this converter's declared + :attr:`target_requirements`. + + Args: + target (PromptTarget): The target to validate. + + Raises: + ValueError: If any required capability cannot be satisfied. 
+ """ + for capability in type(self).target_requirements.required: + target.configuration.ensure_can_handle(capability=capability) + @abc.abstractmethod async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: """ diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index 62ae0ad4c8..899772f01c 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -36,6 +36,7 @@ class TranslationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + target_requirements = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -62,7 +63,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the language is not provided. """ - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) + self._validate_target_requirements(target=converter_target) self.converter_target = converter_target # Retry strategy for the conversion diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 8af7438202..091440da3b 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -35,6 +35,7 @@ class VariationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + target_requirements = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -55,7 +56,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=converter_target.configuration) + self._validate_target_requirements(target=converter_target) self.converter_target = converter_target # set to default strategy if not provided diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index a84488c25f..f8b7eb8133 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -73,7 +73,7 @@ def set_system_prompt( labels: Optional[dict[str, str]] = None, ) -> None: """ - Deprecated shim. Use :meth:`PromptTarget.set_system_prompt` on the base class. + Use :meth:`PromptTarget.set_system_prompt` on the base class. Retained on ``PromptChatTarget`` so subclasses that override this method continue to work. Delegates to the base-class implementation. diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index f018e395b1..1a8ecdee0d 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -4,11 +4,8 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from pyrit.prompt_target.common.target_capabilities import CapabilityName - from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.prompt_target.common.target_capabilities import CapabilityName @dataclass(frozen=True) @@ -17,80 +14,23 @@ class TargetRequirements: Declarative description of what a consumer (attack, converter, scorer) requires from a target. - Consumers define their requirements once and validate them against a - ``TargetConfiguration`` at construction time. This replaces ad-hoc - ``isinstance`` checks and scattered capability branching. 
+ The single source of truth for capability names is the + :class:`CapabilityName` enum; this class is simply a typed wrapper + around the set of capabilities a consumer needs. - Two levels of requirement are supported: - - * ``required_capabilities`` — the target must *handle* the capability, - either natively or via an ``ADAPT`` policy (normalization pipeline). - Use this when the consumer only cares that the behavior is available - on the wire, regardless of how. - * ``required_native_capabilities`` — the target must support the - capability natively; adaptation via the normalization pipeline is - not acceptable. Use this when adaptation would defeat the consumer's - purpose (e.g. a multi-turn attack cannot run against a target whose - history is squashed into a single prompt). + Requirements are satisfied either by native support on the target or + by an ``ADAPT`` entry in the target's + :class:`CapabilityHandlingPolicy`. Consumers that cannot tolerate + adaptation should perform their own ``capabilities.includes(...)`` + check instead of declaring a requirement here. """ - # Capabilities the consumer requires, native or adapted. - required_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) - - # Capabilities the consumer requires to be natively supported. Adaptation - # via the normalization pipeline is not acceptable for these capabilities. - required_native_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) - - def validate(self, *, configuration: TargetConfiguration) -> None: - """ - Validate that the target configuration can satisfy all requirements. - - For ``required_capabilities`` this delegates to - ``TargetConfiguration.ensure_can_handle``, which accepts either - native support or an ``ADAPT`` policy. For - ``required_native_capabilities`` this checks - ``TargetConfiguration.includes`` directly — adaptation is not - acceptable. All violations are collected and reported in a single - ``ValueError``. 
- - Args: - configuration (TargetConfiguration): The target configuration to validate against. - - Raises: - ValueError: If any required capability cannot be satisfied. - """ - errors: list[str] = [] - for capability in sorted(self.required_capabilities, key=lambda c: c.value): - try: - configuration.ensure_can_handle(capability=capability) - except ValueError as exc: - errors.append(str(exc)) - for capability in sorted(self.required_native_capabilities, key=lambda c: c.value): - if not configuration.includes(capability=capability): - errors.append( - f"Target does not natively support '{capability.value}' " - "and adaptation is not acceptable for this consumer." - ) - if errors: - raise ValueError( - f"Target does not satisfy {len(errors)} required capability(ies):\n" - + "\n".join(f" - {e}" for e in errors) - ) - - -def _build_chat_consumer_requirements() -> TargetRequirements: - # Imported lazily to avoid a hard import cycle with target_capabilities at - # module load time (target_requirements only type-checks CapabilityName). - from pyrit.prompt_target.common.target_capabilities import CapabilityName - - return TargetRequirements( - required_capabilities=frozenset({CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN}), - ) + required: frozenset[CapabilityName] = field(default_factory=frozenset) -# Requirements declared by code paths that historically demanded a -# ``PromptChatTarget`` (converters and scorers that call ``set_system_prompt`` -# and then send a short conversation). Adaptation via the normalization -# pipeline is acceptable here — the consumer only needs the *behavior*, not -# native support. -CHAT_CONSUMER_REQUIREMENTS: TargetRequirements = _build_chat_consumer_requirements() +# Shared requirement used by scorers and converters that set a system prompt +# and drive a short multi-turn conversation. Adaptation is acceptable: the +# consumer only needs the behavior on the wire, not native support. 
+CHAT_CONSUMER_REQUIREMENTS = TargetRequirements( + required=frozenset({CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN}), +) diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 85fdfe86e6..77d0c468e4 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -21,6 +21,7 @@ class InsecureCodeScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -40,7 +41,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index 60a87ade05..57bae35cde 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -173,6 +173,7 @@ class SelfAskLikertScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -203,7 +204,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target self._likert_scale = likert_scale diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index d986293354..69abc13e9e 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -39,6 
+39,7 @@ class SystemPaths(enum.Enum): supported_data_types=["text"], is_objective_required=True, ) + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -61,7 +62,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 5ae7dcb5c9..e7a50390b3 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -12,6 +12,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Optional, Union, cast, @@ -35,6 +36,7 @@ UnvalidatedScore, ) from pyrit.prompt_target.batch_helper import batch_task_async +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from collections.abc import Sequence @@ -59,6 +61,11 @@ class Scorer(Identifiable, abc.ABC): # Specifies glob patterns for datasets and a result file name. evaluation_file_mapping: Optional[ScorerEvalDatasetFiles] = None + #: Capability requirements placed on the scorer's chat target (if any). + #: Subclasses that use a chat target should override this and call + #: ``self._validate_target_requirements(target=chat_target)`` in ``__init__``. + target_requirements: ClassVar[TargetRequirements] = TargetRequirements() + _identifier: Optional[ComponentIdentifier] = None def __init__(self, *, validator: ScorerPromptValidator): @@ -70,6 +77,20 @@ def __init__(self, *, validator: ScorerPromptValidator): """ self._validator = validator + def _validate_target_requirements(self, *, target: PromptTarget) -> None: + """ + Validate ``target`` against this scorer's declared + :attr:`target_requirements`. + + Args: + target (PromptTarget): The target to validate. + + Raises: + ValueError: If any required capability cannot be satisfied. 
+ """ + for capability in type(self).target_requirements.required: + target.configuration.ensure_can_handle(capability=capability) + def get_identifier(self) -> ComponentIdentifier: """ Get the scorer's identifier with eval_hash always attached. diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index b0a311bc86..5884f87dc1 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -30,6 +30,7 @@ class GandalfScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -51,7 +52,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target self._defender = level.value self._endpoint = "https://gandalf-api.lakera.ai/api/guess-password" diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index d9444e948d..330615e2b6 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -37,6 +37,7 @@ class SelfAskCategoryScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -58,7 +59,7 @@ def __init__( """ super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target content_classifier_path = verify_and_resolve_path(content_classifier_path) diff --git 
a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index 67e35c17c5..d509e63eee 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -64,6 +64,7 @@ class SelfAskRefusalScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -102,7 +103,7 @@ def __init__( super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target # Resolve the system prompt path diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index da1a7c2255..ec8bf57da5 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -93,6 +93,7 @@ class SelfAskTrueFalseScorer(TrueFalseScorer): _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator( supported_data_types=["text", "image_path"], ) + target_requirements = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -122,7 +123,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) - CHAT_CONSUMER_REQUIREMENTS.validate(configuration=chat_target.configuration) + self._validate_target_requirements(target=chat_target) self._prompt_target = chat_target if true_false_question_path and true_false_question: diff --git a/tests/unit/prompt_target/target/test_target_requirements.py b/tests/unit/prompt_target/target/test_target_requirements.py index a53550d3bf..763f4573d1 100644 --- a/tests/unit/prompt_target/target/test_target_requirements.py +++ b/tests/unit/prompt_target/target/test_target_requirements.py @@ -3,176 +3,32 
@@ import pytest -from pyrit.prompt_target.common.target_capabilities import ( - CapabilityHandlingPolicy, +from pyrit.prompt_target import ( + CHAT_CONSUMER_REQUIREMENTS, CapabilityName, - TargetCapabilities, - UnsupportedCapabilityBehavior, + TargetRequirements, ) -from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.target_requirements import TargetRequirements -@pytest.fixture -def adapt_all_policy(): - return CapabilityHandlingPolicy( - behaviors={ - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, - } - ) - - -# --------------------------------------------------------------------------- -# Construction -# --------------------------------------------------------------------------- - - -def test_init_default_has_empty_capabilities(): - reqs = TargetRequirements() - assert reqs.required_capabilities == frozenset() - - -def test_init_with_capabilities(): - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) - ) - assert CapabilityName.MULTI_TURN in reqs.required_capabilities - assert CapabilityName.SYSTEM_PROMPT in reqs.required_capabilities - - -# --------------------------------------------------------------------------- -# validate — all pass -# --------------------------------------------------------------------------- - - -def test_validate_passes_when_target_supports_all_natively(): - caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements( - 
required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) - ) - reqs.validate(configuration=config) - - -def test_validate_passes_when_policy_is_adapt(adapt_all_policy): - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) - ) - reqs.validate(configuration=config) - - -def test_validate_passes_with_empty_requirements(): - caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements() - reqs.validate(configuration=config) +def test_default_requirements_require_nothing(): + assert TargetRequirements().required == frozenset() -# --------------------------------------------------------------------------- -# validate — failures -# --------------------------------------------------------------------------- - - -def test_validate_raises_when_capability_missing_and_no_policy(): - # EDITABLE_HISTORY has no normalizer and no handling policy — validate raises. - caps = TargetCapabilities(supports_editable_history=False, supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.EDITABLE_HISTORY})) - with pytest.raises(ValueError, match="supports_editable_history"): - reqs.validate(configuration=config) - - -def test_validate_raises_when_capability_missing_and_policy_raise(adapt_all_policy): - # json_output is missing and the policy is RAISE — validate raises. 
- caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.JSON_OUTPUT})) - with pytest.raises(ValueError, match="supports_json_output"): - reqs.validate(configuration=config) - - -def test_validate_collects_all_unsatisfied_capabilities(adapt_all_policy): - """When multiple capabilities are missing, validate reports all violations.""" - caps = TargetCapabilities( - supports_multi_turn=False, - supports_system_prompt=False, - supports_json_output=False, - supports_editable_history=False, - ) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - # json_output => RAISE, editable_history => no policy (raises) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.JSON_OUTPUT, CapabilityName.EDITABLE_HISTORY}) - ) - with pytest.raises(ValueError, match="2 required capability") as exc_info: - reqs.validate(configuration=config) - assert "supports_json_output" in str(exc_info.value) - assert "supports_editable_history" in str(exc_info.value) - - -def test_validate_mixed_adapt_and_raise(adapt_all_policy): - """One capability adapts but another raises — validate should raise.""" - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - # multi_turn and system_prompt => ADAPT (OK), json_output => RAISE (fail) +def test_construction_from_frozenset(): reqs = TargetRequirements( - required_capabilities=frozenset( - {CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT, CapabilityName.JSON_OUTPUT} - ) + required=frozenset({CapabilityName.MULTI_TURN, CapabilityName.JSON_OUTPUT}), ) - with pytest.raises(ValueError, match="supports_json_output"): - reqs.validate(configuration=config) + assert 
reqs.required == {CapabilityName.MULTI_TURN, CapabilityName.JSON_OUTPUT} -# --------------------------------------------------------------------------- -# required_native_capabilities -# --------------------------------------------------------------------------- +def test_chat_consumer_requirements_shape(): + assert CHAT_CONSUMER_REQUIREMENTS.required == { + CapabilityName.SYSTEM_PROMPT, + CapabilityName.MULTI_TURN, + } -def test_validate_native_passes_when_supported_natively(): - caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements(required_native_capabilities=frozenset({CapabilityName.MULTI_TURN})) - reqs.validate(configuration=config) - - -def test_validate_native_raises_when_only_adapted(adapt_all_policy): - # multi_turn is missing but ADAPT — acceptable for required_capabilities - # but not for required_native_capabilities. - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements(required_native_capabilities=frozenset({CapabilityName.MULTI_TURN})) - with pytest.raises(ValueError, match="natively support 'supports_multi_turn'"): - reqs.validate(configuration=config) - - -def test_validate_native_and_adapted_mixed(adapt_all_policy): - # system_prompt adapted (OK for required_capabilities); multi_turn required - # natively (FAIL — only adapted). 
- caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.SYSTEM_PROMPT}), - required_native_capabilities=frozenset({CapabilityName.MULTI_TURN}), - ) - with pytest.raises(ValueError, match="natively support 'supports_multi_turn'"): - reqs.validate(configuration=config) - - -def test_validate_native_collects_multiple_violations(): - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements( - required_native_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) - ) - with pytest.raises(ValueError, match="2 required capability") as exc_info: - reqs.validate(configuration=config) - assert "supports_multi_turn" in str(exc_info.value) - assert "supports_system_prompt" in str(exc_info.value) +def test_requirements_are_frozen(): + reqs = TargetRequirements(required=frozenset({CapabilityName.MULTI_TURN})) + with pytest.raises(Exception): + reqs.required = frozenset() # type: ignore[misc] From c00285c7e73a4710fb5e9c77ec4046db2e67e98b Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 22 Apr 2026 17:48:25 -0400 Subject: [PATCH 4/8] move validation func --- pyrit/executor/attack/core/attack_strategy.py | 5 ++-- .../llm_generic_text_converter.py | 4 ++-- .../prompt_converter/persuasion_converter.py | 4 ++-- pyrit/prompt_converter/prompt_converter.py | 23 +++---------------- .../prompt_converter/translation_converter.py | 4 ++-- pyrit/prompt_converter/variation_converter.py | 4 ++-- .../common/target_requirements.py | 18 +++++++++++++++ .../score/float_scale/insecure_code_scorer.py | 4 ++-- .../float_scale/self_ask_likert_scorer.py | 4 ++-- .../float_scale/self_ask_scale_scorer.py | 4 ++-- pyrit/score/scorer.py | 18 ++------------- 
pyrit/score/true_false/gandalf_scorer.py | 4 ++-- .../true_false/self_ask_category_scorer.py | 4 ++-- .../true_false/self_ask_refusal_scorer.py | 4 ++-- .../true_false/self_ask_true_false_scorer.py | 4 ++-- 15 files changed, 47 insertions(+), 61 deletions(-) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 87ef16a813..2c33a768ff 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -236,7 +236,7 @@ class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Id #: Capability requirements placed on ``objective_target``. Subclasses override #: to declare what the attack needs. Validated in ``__init__``. - target_requirements: ClassVar[TargetRequirements] = TargetRequirements() + TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() def __init__( self, @@ -264,8 +264,7 @@ def __init__( ), logger=logger, ) - for capability in type(self).target_requirements.required: - objective_target.configuration.ensure_can_handle(capability=capability) + type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._objective_target = objective_target self._params_type = params_type # Guard so subclasses that set converters before calling super() aren't clobbered diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index 31614f135b..59808827d7 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -27,7 +27,7 @@ class LLMGenericTextConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -54,7 +54,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" - self._validate_target_requirements(target=converter_target) + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self._converter_target = converter_target self._system_prompt_template = system_prompt_template self._prompt_kwargs = kwargs diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index 85c441fd3e..363405c8bd 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -47,7 +47,7 @@ class PersuasionConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -70,7 +70,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the persuasion technique is not supported or does not exist. """ - self._validate_target_requirements(target=converter_target) + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self.converter_target = converter_target try: diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index dc82dbbd0e..b245a0c846 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,16 +6,13 @@ import inspect import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, get_args +from typing import Any, ClassVar, Optional, Union, get_args from pyrit import prompt_converter from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.models import PromptDataType from pyrit.prompt_target.common.target_requirements import TargetRequirements -if TYPE_CHECKING: - from pyrit.prompt_target import PromptTarget - @dataclass class ConverterResult: @@ -54,8 +51,8 @@ class PromptConverter(Identifiable): #: Capability requirements placed 
on the converter's target (if any). #: Subclasses that use a target should override this and call - #: ``self._validate_target_requirements(target=converter_target)`` in ``__init__``. - target_requirements: ClassVar[TargetRequirements] = TargetRequirements() + #: ``type(self).TARGET_REQUIREMENTS.validate(target=converter_target)`` in ``__init__``. + TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() _identifier: Optional[ComponentIdentifier] = None @@ -90,20 +87,6 @@ def __init__(self) -> None: """ super().__init__() - def _validate_target_requirements(self, *, target: "PromptTarget") -> None: - """ - Validate ``target`` against this converter's declared - :attr:`target_requirements`. - - Args: - target (PromptTarget): The target to validate. - - Raises: - ValueError: If any required capability cannot be satisfied. - """ - for capability in type(self).target_requirements.required: - target.configuration.ensure_can_handle(capability=capability) - @abc.abstractmethod async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: """ diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index 899772f01c..0228a0265c 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -36,7 +36,7 @@ class TranslationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -63,7 +63,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the language is not provided. 
""" - self._validate_target_requirements(target=converter_target) + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self.converter_target = converter_target # Retry strategy for the conversion diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 091440da3b..e7d35e76ee 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -35,7 +35,7 @@ class VariationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( @@ -56,7 +56,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. """ - self._validate_target_requirements(target=converter_target) + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self.converter_target = converter_target # set to default strategy if not provided diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index 1a8ecdee0d..34595b9af8 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -4,9 +4,13 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import TYPE_CHECKING from pyrit.prompt_target.common.target_capabilities import CapabilityName +if TYPE_CHECKING: + from pyrit.prompt_target.common.prompt_target import PromptTarget + @dataclass(frozen=True) class TargetRequirements: @@ -27,6 +31,20 @@ class TargetRequirements: required: frozenset[CapabilityName] = field(default_factory=frozenset) + def validate(self, *, target: PromptTarget) -> None: + """ + Validate that ``target`` can satisfy every required capability. + + Args: + target (PromptTarget): The target to validate against. 
+ + Raises: + ValueError: If any required capability is not supported natively + and has no ``ADAPT`` entry in the target's policy. + """ + for capability in self.required: + target.configuration.ensure_can_handle(capability=capability) + # Shared requirement used by scorers and converters that set a system prompt # and drive a short multi-turn conversation. Adaptation is acceptable: the diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 77d0c468e4..eb20cebcb9 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -21,7 +21,7 @@ class InsecureCodeScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -41,7 +41,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index 57bae35cde..2b140033b9 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -173,7 +173,7 @@ class SelfAskLikertScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -204,7 +204,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) 
self._prompt_target = chat_target self._likert_scale = likert_scale diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 69abc13e9e..5e7de81c69 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -39,7 +39,7 @@ class SystemPaths(enum.Enum): supported_data_types=["text"], is_objective_required=True, ) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -62,7 +62,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index e7a50390b3..42f50f4e33 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -63,8 +63,8 @@ class Scorer(Identifiable, abc.ABC): #: Capability requirements placed on the scorer's chat target (if any). #: Subclasses that use a chat target should override this and call - #: ``self._validate_target_requirements(target=chat_target)`` in ``__init__``. - target_requirements: ClassVar[TargetRequirements] = TargetRequirements() + #: ``type(self).TARGET_REQUIREMENTS.validate(target=chat_target)`` in ``__init__``. + TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() _identifier: Optional[ComponentIdentifier] = None @@ -77,20 +77,6 @@ def __init__(self, *, validator: ScorerPromptValidator): """ self._validator = validator - def _validate_target_requirements(self, *, target: PromptTarget) -> None: - """ - Validate ``target`` against this scorer's declared - :attr:`target_requirements`. - - Args: - target (PromptTarget): The target to validate. - - Raises: - ValueError: If any required capability cannot be satisfied. 
- """ - for capability in type(self).target_requirements.required: - target.configuration.ensure_can_handle(capability=capability) - def get_identifier(self) -> ComponentIdentifier: """ Get the scorer's identifier with eval_hash always attached. diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 5884f87dc1..a3b838f190 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -30,7 +30,7 @@ class GandalfScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -52,7 +52,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target self._defender = level.value self._endpoint = "https://gandalf-api.lakera.ai/api/guess-password" diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 330615e2b6..043d0d6b19 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -37,7 +37,7 @@ class SelfAskCategoryScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -59,7 +59,7 @@ def __init__( """ super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target content_classifier_path = 
verify_and_resolve_path(content_classifier_path) diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index d509e63eee..ecbb899fe6 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -64,7 +64,7 @@ class SelfAskRefusalScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -103,7 +103,7 @@ def __init__( super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target # Resolve the system prompt path diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index ec8bf57da5..ffc357d5f5 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -93,7 +93,7 @@ class SelfAskTrueFalseScorer(TrueFalseScorer): _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator( supported_data_types=["text", "image_path"], ) - target_requirements = CHAT_CONSUMER_REQUIREMENTS + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -123,7 +123,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) - self._validate_target_requirements(target=chat_target) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if true_false_question_path and true_false_question: From 8a43590e43b3694efda8e66280688c7ac21174fb Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 22 Apr 2026 18:37:33 -0400 Subject: [PATCH 5/8] remove remaining promptchattarget and 
unused requirements --- .../attack/component/conversation_manager.py | 10 ++- pyrit/executor/attack/core/attack_strategy.py | 8 +- .../attack/multi_turn/chunked_request.py | 3 +- pyrit/executor/attack/multi_turn/crescendo.py | 13 ++- .../attack/multi_turn/tree_of_attacks.py | 15 ++-- .../self_ask_general_float_scale_scorer.py | 7 +- .../self_ask_general_true_false_scorer.py | 7 +- .../self_ask_question_answer_scorer.py | 4 +- .../target/test_target_requirements.py | 82 +++++++++++++++++++ 9 files changed, 120 insertions(+), 29 deletions(-) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 8556125318..2a3561c780 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -20,7 +20,6 @@ ) from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: @@ -243,7 +242,7 @@ def get_last_message( def set_system_prompt( self, *, - target: PromptChatTarget, + target: PromptTarget, conversation_id: str, system_prompt: str, labels: Optional[dict[str, str]] = None, @@ -252,11 +251,16 @@ def set_system_prompt( Set or update the system prompt for a conversation. Args: - target: The chat target to set the system prompt on. + target: The target to set the system prompt on. Must handle the + ``SYSTEM_PROMPT`` capability (natively or via an ``ADAPT`` policy). conversation_id: Unique identifier for the conversation. system_prompt: The system prompt text. labels: Optional labels to associate with the system prompt. + + Raises: + ValueError: If ``target`` cannot handle the ``SYSTEM_PROMPT`` capability. 
""" + target.configuration.ensure_can_handle(capability=CapabilityName.SYSTEM_PROMPT) target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 2c33a768ff..dac86d10aa 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -8,7 +8,7 @@ import time from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -27,7 +27,6 @@ ConversationReference, Message, ) -from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig @@ -234,10 +233,6 @@ class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Id Defines the interface for executing attacks and handling results. """ - #: Capability requirements placed on ``objective_target``. Subclasses override - #: to declare what the attack needs. Validated in ``__init__``. 
- TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() - def __init__( self, *, @@ -264,7 +259,6 @@ def __init__( ), logger=logger, ) - type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._objective_target = objective_target self._params_type = params_type # Guard so subclasses that set converters before calling super() aren't clobbered diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 57b6c8816d..a6bc17dcfe 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -146,8 +146,7 @@ def __init__( # adaptation would collapse them into a single prompt. if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): raise ValueError( - "ChunkedRequestAttack requires a target that natively supports " - f"'{CapabilityName.MULTI_TURN.value}'." + f"ChunkedRequestAttack requires a target that natively supports '{CapabilityName.MULTI_TURN.value}'." 
) # Store chunk configuration diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 0419e4c2f6..957f209b1d 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -43,7 +43,7 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.score import ( FloatScaleThresholdScorer, @@ -122,7 +122,7 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA def __init__( self, *, - objective_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, @@ -135,7 +135,8 @@ def __init__( Initialize the Crescendo attack strategy. Args: - objective_target (PromptChatTarget): The target system to attack. Must be a PromptChatTarget. + objective_target (PromptTarget): The target system to attack. Must natively + support multi-turn conversations. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial component, including the adversarial chat target and optional system prompt path. attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters, @@ -149,8 +150,7 @@ def __init__( application by role, message normalization, and non-chat target behavior. Raises: - ValueError: If objective_target is not a PromptChatTarget, or does not - natively support multi-turn conversations. + ValueError: If ``objective_target`` does not natively support multi-turn conversations. 
""" # Initialize base class super().__init__(objective_target=objective_target, logger=logger, context_type=CrescendoAttackContext) @@ -159,8 +159,7 @@ def __init__( # gradually escalate prompts; history-squash adaptation would defeat it. if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): raise ValueError( - "CrescendoAttack requires a target that natively supports " - f"'{CapabilityName.MULTI_TURN.value}'." + f"CrescendoAttack requires a target that natively supports '{CapabilityName.MULTI_TURN.value}'." ) self._memory = CentralMemory.get_memory_instance() diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 7772ffb372..39af0921db 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -50,7 +50,7 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.score import ( FloatScaleThresholdScorer, @@ -258,7 +258,7 @@ class _TreeOfAttacksNode: def __init__( self, *, - objective_target: PromptChatTarget, + objective_target: PromptTarget, adversarial_chat: PromptChatTarget, adversarial_chat_seed_prompt: SeedPrompt, adversarial_chat_prompt_template: SeedPrompt, @@ -280,7 +280,7 @@ def __init__( Initialize a tree node. Args: - objective_target (PromptChatTarget): The target to attack. + objective_target (PromptTarget): The target to attack. adversarial_chat (PromptChatTarget): The chat target for generating adversarial prompts. adversarial_chat_seed_prompt (SeedPrompt): The seed prompt for the first turn. adversarial_chat_prompt_template (SeedPrompt): The template for subsequent turns. 
@@ -1255,7 +1255,7 @@ class TreeOfAttacksWithPruningAttack(AttackStrategy[TAPAttackContext, TAPAttackR def __init__( self, *, - objective_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, @@ -1272,7 +1272,7 @@ def __init__( Initialize the Tree of Attacks with Pruning attack strategy. Args: - objective_target (PromptChatTarget): The target system to attack. + objective_target (PromptTarget): The target system to attack. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial chat component. attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters. Defaults to None. @@ -1294,7 +1294,8 @@ def __init__( Raises: ValueError: If attack_scoring_config uses a non-FloatScaleThresholdScorer objective scorer, - if target is not PromptChatTarget, or if parameters are invalid. + if the adversarial target does not natively support the capabilities TAP needs, + or if parameters are invalid. """ # Validate tree parameters if tree_depth < 1: @@ -1869,7 +1870,7 @@ def _create_attack_node( generate adversarial prompts and evaluate responses. 
""" node = _TreeOfAttacksNode( - objective_target=cast("PromptChatTarget", self._objective_target), + objective_target=self._objective_target, adversarial_chat=self._adversarial_chat, adversarial_chat_seed_prompt=self._adversarial_chat_seed_prompt, adversarial_chat_system_seed_prompt=self._adversarial_chat_system_seed_prompt, diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index b2796106ff..9750a936a3 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -24,6 +25,7 @@ class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): supported_data_types=["text"], is_objective_required=True, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -52,7 +54,9 @@ def __init__( in the response, the provided `category` argument will be applied. Args: - chat_target (PromptChatTarget): The chat target used to score. + chat_target (PromptTarget): The chat target used to score. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, + possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, prompt, and message_piece. prompt_format_string (Optional[str]): User prompt template with the same placeholders. @@ -72,6 +76,7 @@ def __init__( ValueError: If min_value is greater than max_value. 
""" super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_format_string: raise ValueError("system_prompt_format_string must be provided and non-empty.") diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 5f7f02f71d..ffd34c4dd2 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -28,6 +29,7 @@ class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): supported_data_types=["text"], is_objective_required=False, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, @@ -55,7 +57,9 @@ def __init__( in the response, the provided `category` argument will be applied. Args: - chat_target (PromptChatTarget): The chat target used to score. + chat_target (PromptTarget): The chat target used to score. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, + possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, task (alias of objective), prompt, and message_piece. prompt_format_string (Optional[str]): User prompt template with the same placeholders. @@ -74,6 +78,7 @@ def __init__( ValueError: If system_prompt_format_string is not provided or empty. 
""" super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_format_string: raise ValueError("system_prompt_format_string must be provided and non-empty.") diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index 69b453be24..b52ed9fa79 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -46,7 +46,9 @@ def __init__( Initialize the SelfAskQuestionAnswerScorer object. Args: - chat_target (PromptChatTarget): The chat target to use for the scorer. + chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, + possibly via normalization-pipeline adaptation). true_false_question_path (Optional[pathlib.Path]): The path to the true/false question file. Defaults to None, which uses the default question_answering.yaml file. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. diff --git a/tests/unit/prompt_target/target/test_target_requirements.py b/tests/unit/prompt_target/target/test_target_requirements.py index 763f4573d1..0774afdda2 100644 --- a/tests/unit/prompt_target/target/test_target_requirements.py +++ b/tests/unit/prompt_target/target/test_target_requirements.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from unittest.mock import MagicMock + import pytest from pyrit.prompt_target import ( @@ -8,6 +10,18 @@ CapabilityName, TargetRequirements, ) +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration + + +def _make_target(*, configuration: TargetConfiguration) -> MagicMock: + target = MagicMock() + target.configuration = configuration + return target def test_default_requirements_require_nothing(): @@ -32,3 +46,71 @@ def test_requirements_are_frozen(): reqs = TargetRequirements(required=frozenset({CapabilityName.MULTI_TURN})) with pytest.raises(Exception): reqs.required = frozenset() # type: ignore[misc] + + +def test_validate_passes_on_native_support(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=True, + ), + ), + ) + + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_passes_when_policy_is_adapt(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + }, + ), + ), + ) + + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_raises_when_capability_neither_native_nor_adapt(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + }, + ), + ), + ) + + with 
pytest.raises(ValueError, match=CapabilityName.SYSTEM_PROMPT.value): + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_empty_required_always_passes(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities(), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + }, + ), + ), + ) + + TargetRequirements().validate(target=target) From b1edf253462a972f159ac256ad2647fc2151e955 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 22 Apr 2026 19:33:21 -0400 Subject: [PATCH 6/8] remove lingering chat target ref --- .../attack/component/conversation_manager.py | 8 ++--- pyrit/executor/attack/core/attack_strategy.py | 8 ++++- pyrit/executor/attack/multi_turn/crescendo.py | 17 ++++++----- .../attack/multi_turn/multi_prompt_sending.py | 18 ++++++----- .../common/prompt_chat_target.py | 20 ------------- .../common/target_requirements.py | 30 ++++++++++++++----- 6 files changed, 53 insertions(+), 48 deletions(-) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 2a3561c780..f1d86a4a35 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -288,7 +288,7 @@ async def initialize_context_async( 3. Updates context.executed_turns for multi-turn attacks 4. Sets context.next_message if there's an unanswered user message - For PromptChatTarget: + For chat-capable PromptTarget: - Adds prepended messages to memory with simulated_assistant role - All messages get new UUIDs @@ -311,7 +311,7 @@ async def initialize_context_async( Raises: ValueError: If conversation_id is empty, or if prepended_conversation - requires a PromptChatTarget but target is not one. + requires a chat-capable PromptTarget but target is not one. 
""" if not conversation_id: raise ValueError("conversation_id cannot be empty") @@ -374,8 +374,8 @@ async def _handle_non_chat_target_async( if config.non_chat_target_behavior == "raise": raise ValueError( - "prepended_conversation requires the objective target to be a PromptChatTarget. " - "Non-chat objective targets do not support conversation history. " + "prepended_conversation requires the objective target to be a chat-capable " + "PromptTarget.Non-chat objective targets do not support conversation history. " "Use PrependedConversationConfig with non_chat_target_behavior='normalize_first_turn' " "to normalize the conversation into the first message instead." ) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index dac86d10aa..efe2dcca8e 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -8,7 +8,7 @@ import time from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -27,6 +27,7 @@ ConversationReference, Message, ) +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig @@ -233,6 +234,10 @@ class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Id Defines the interface for executing attacks and handling results. """ + #: Capability requirements placed on ``objective_target``. Subclasses + #: override to declare what the attack needs. Validated in ``__init__``. 
+ TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + def __init__( self, *, @@ -259,6 +264,7 @@ def __init__( ), logger=logger, ) + type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._objective_target = objective_target self._params_type = params_type # Guard so subclasses that set converters before calling super() aren't clobbered diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 957f209b1d..3f0632e9f8 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -45,6 +45,7 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -113,6 +114,15 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA You can learn more about the Crescendo attack [@russinovich2024crescendo]. """ + # Crescendo fundamentally relies on multi-turn conversation history to + # gradually escalate prompts; history-squash adaptation would collapse the + # conversation into a single prompt and silently break the attack's + # semantics. Declare MULTI_TURN as ``native_required`` so adaptation is + # rejected at construction time. 
+ TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.MULTI_TURN}), + ) + # Default system prompt template path for Crescendo attack DEFAULT_ADVERSARIAL_CHAT_SYSTEM_PROMPT_TEMPLATE_PATH: Path = ( Path(EXECUTOR_SEED_PROMPT_PATH) / "crescendo" / "crescendo_variant_1.yaml" @@ -155,13 +165,6 @@ def __init__( # Initialize base class super().__init__(objective_target=objective_target, logger=logger, context_type=CrescendoAttackContext) - # Crescendo fundamentally relies on multi-turn conversation history to - # gradually escalate prompts; history-squash adaptation would defeat it. - if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): - raise ValueError( - f"CrescendoAttack requires a target that natively supports '{CapabilityName.MULTI_TURN.value}'." - ) - self._memory = CentralMemory.get_memory_instance() # Initialize converter configuration diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index c330492a9f..3239c5568e 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -30,6 +30,7 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import Scorer if TYPE_CHECKING: @@ -124,6 +125,15 @@ class MultiPromptSendingAttack(MultiTurnAttackStrategy[MultiTurnAttackContext[An and multiple scorer types for comprehensive evaluation. """ + # Sending a sequence of distinct prompts depends on the target maintaining + # conversation state between them. History-squash adaptation would collapse + # them into one message and silently break the attack's sequencing + # semantics. 
Declare MULTI_TURN as ``native_required`` so adaptation is + # rejected at construction time. + TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.MULTI_TURN}), + ) + @apply_defaults def __init__( self, @@ -153,14 +163,6 @@ def __init__( params_type=MultiPromptSendingAttackParameters, ) - # Sending a sequence of prompts requires a real multi-turn target; - # history-squash adaptation would collapse them into one message. - if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): - raise ValueError( - "MultiPromptSendingAttack requires a target that natively supports " - f"'{CapabilityName.MULTI_TURN.value}'." - ) - # Initialize the converter configuration attack_converter_config = attack_converter_config or AttackConverterConfig() self._request_converters = attack_converter_config.request_converters diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index f8b7eb8133..22ece7f603 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -64,26 +64,6 @@ def __init__( custom_capabilities=custom_capabilities, ) - def set_system_prompt( - self, - *, - system_prompt: str, - conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, - ) -> None: - """ - Use :meth:`PromptTarget.set_system_prompt` on the base class. - - Retained on ``PromptChatTarget`` so subclasses that override this method - continue to work. Delegates to the base-class implementation. 
- """ - super().set_system_prompt( - system_prompt=system_prompt, - conversation_id=conversation_id, - attack_identifier=attack_identifier, - labels=labels, - ) def is_response_format_json(self, message_piece: MessagePiece) -> bool: """ diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index 34595b9af8..5ffee9ae9e 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -22,26 +22,40 @@ class TargetRequirements: :class:`CapabilityName` enum; this class is simply a typed wrapper around the set of capabilities a consumer needs. - Requirements are satisfied either by native support on the target or - by an ``ADAPT`` entry in the target's - :class:`CapabilityHandlingPolicy`. Consumers that cannot tolerate - adaptation should perform their own ``capabilities.includes(...)`` - check instead of declaring a requirement here. + Two tiers of requirement are supported: + + * ``required`` — satisfied either by native support on the target or + by an ``ADAPT`` entry in the target's + :class:`CapabilityHandlingPolicy`. Use this when the consumer only + needs the behavior to appear on the wire. + * ``native_required`` — must be natively supported. Adaptation is + rejected. Use this when adaptation would silently change the + consumer's semantics (e.g. an attack that depends on the target + remembering prior turns, where history-squash normalization would + collapse the conversation into a single prompt). """ required: frozenset[CapabilityName] = field(default_factory=frozenset) + native_required: frozenset[CapabilityName] = field(default_factory=frozenset) def validate(self, *, target: PromptTarget) -> None: """ - Validate that ``target`` can satisfy every required capability. + Validate that ``target`` can satisfy every declared requirement. Args: target (PromptTarget): The target to validate against. 
Raises: - ValueError: If any required capability is not supported natively - and has no ``ADAPT`` entry in the target's policy. + ValueError: If any ``native_required`` capability is not natively + supported, or if any ``required`` capability is not supported + natively and has no ``ADAPT`` entry in the target's policy. """ + for capability in self.native_required: + if not target.configuration.includes(capability=capability): + raise ValueError( + f"Target must natively support '{capability.value}'; " + "adaptation is not acceptable for this consumer." + ) for capability in self.required: target.configuration.ensure_can_handle(capability=capability) From 1915c8263369f4fcf69b8c65c2a6db9dfee8dc20 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 22 Apr 2026 20:24:35 -0400 Subject: [PATCH 7/8] fix docstrings --- .../attack/component/prepended_conversation_config.py | 6 +++--- .../attack/multi_turn/multi_turn_attack_strategy.py | 4 +++- pyrit/executor/attack/multi_turn/tree_of_attacks.py | 2 +- pyrit/prompt_converter/denylist_converter.py | 2 +- .../malicious_question_generator_converter.py | 2 +- pyrit/prompt_converter/math_prompt_converter.py | 2 +- pyrit/prompt_converter/noise_converter.py | 2 +- pyrit/prompt_converter/random_translation_converter.py | 2 +- pyrit/prompt_converter/tense_converter.py | 2 +- pyrit/prompt_converter/tone_converter.py | 2 +- .../prompt_converter/toxic_sentence_generator_converter.py | 2 +- pyrit/prompt_target/common/prompt_chat_target.py | 6 ++---- pyrit/prompt_target/common/prompt_target.py | 6 +++--- pyrit/score/float_scale/insecure_code_scorer.py | 2 +- pyrit/score/float_scale/self_ask_likert_scorer.py | 2 +- pyrit/score/float_scale/self_ask_scale_scorer.py | 2 +- pyrit/score/scorer.py | 2 +- pyrit/score/true_false/self_ask_category_scorer.py | 2 +- pyrit/score/true_false/self_ask_refusal_scorer.py | 2 +- pyrit/score/true_false/self_ask_true_false_scorer.py | 2 +- .../attack/multi_turn/test_supports_multi_turn_attacks.py | 3 
+++ 21 files changed, 30 insertions(+), 27 deletions(-) diff --git a/pyrit/executor/attack/component/prepended_conversation_config.py b/pyrit/executor/attack/component/prepended_conversation_config.py index c78ffad767..fddeae5371 100644 --- a/pyrit/executor/attack/component/prepended_conversation_config.py +++ b/pyrit/executor/attack/component/prepended_conversation_config.py @@ -22,7 +22,7 @@ class PrependedConversationConfig: This class provides control over: - Which message roles should have request converters applied - How to normalize conversation history for non-chat objective targets - - What to do when the objective target is not a PromptChatTarget + - What to do when the objective target is not a chat-capable PromptTarget """ # Roles for which request converters should be applied to prepended messages. @@ -36,13 +36,13 @@ class PrependedConversationConfig: # ConversationContextNormalizer is used that produces "Turn N: User/Assistant" format. message_normalizer: Optional[MessageStringNormalizer] = None - # Behavior when the target is a PromptTarget but not a PromptChatTarget: + # Behavior when the target is a PromptTarget but not a chat-capable PromptTarget: # - "normalize_first_turn": Normalize the prepended conversation into a string and # store it in ConversationState.normalized_prepended_context. This context will be # prepended to the first message sent to the target. Uses objective_target_context_normalizer # if provided, otherwise falls back to ConversationContextNormalizer. # - "raise": Raise a ValueError. Use this when prepended conversation history must be - # maintained by the target (i.e., target must be a PromptChatTarget). + # maintained by the target (i.e., target must be a chat-capable PromptTarget). 
non_chat_target_behavior: Literal["normalize_first_turn", "raise"] = "normalize_first_turn" def get_message_normalizer(self) -> MessageStringNormalizer: diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 04c8084f7b..995ec696b0 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -19,6 +19,8 @@ from pyrit.memory import CentralMemory from pyrit.models import ConversationReference, ConversationType +from pyrit.prompt_target.common.target_capabilities import CapabilityName + if TYPE_CHECKING: from pyrit.models import ( Message, @@ -117,7 +119,7 @@ def _rotate_conversation_for_single_turn_target( Args: context: The current attack context. """ - if self._objective_target.capabilities.supports_multi_turn: + if self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): return if context.executed_turns == 0: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 39af0921db..cbd604ccb1 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -781,7 +781,7 @@ def duplicate(self) -> "_TreeOfAttacksNode": # For single-turn targets, duplicate only the system messages (e.g., system prompt # from prepended conversation) so the target retains its configuration without # carrying over attack turn history that would cause validation errors. 
- if self._objective_target.capabilities.supports_multi_turn: + if self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation( conversation_id=self.objective_target_conversation_id ) diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index 969472c5fb..7cbc6f5e9e 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -19,7 +19,7 @@ class DenylistConverter(LLMGenericTextConverter): """ Replaces forbidden words or phrases in a prompt with synonyms using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py index fb9e225261..5725fff9c4 100644 --- a/pyrit/prompt_converter/malicious_question_generator_converter.py +++ b/pyrit/prompt_converter/malicious_question_generator_converter.py @@ -19,7 +19,7 @@ class MaliciousQuestionGeneratorConverter(LLMGenericTextConverter): """ Generates malicious questions using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py index 8d809a6661..a3520b190c 100644 --- a/pyrit/prompt_converter/math_prompt_converter.py +++ b/pyrit/prompt_converter/math_prompt_converter.py @@ -19,7 +19,7 @@ class MathPromptConverter(LLMGenericTextConverter): """ Converts natural language instructions into symbolic mathematics problems using an LLM. 
- An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index a89e2d85c7..86c5375773 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -20,7 +20,7 @@ class NoiseConverter(LLMGenericTextConverter): """ Injects noise errors into a conversation using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 4711e0e0d5..7e11810323 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -22,7 +22,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): """ Translates each individual word in a prompt to a random language using an LLM. - An existing ``PromptChatTarget`` is used to perform the translation (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the translation (like Azure OpenAI). """ SUPPORTED_INPUT_TYPES = ("text",) diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index 66a64d0158..eede7adef9 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -19,7 +19,7 @@ class TenseConverter(LLMGenericTextConverter): """ Converts a conversation to a different tense using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). 
""" @apply_defaults diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index 69ee355aaf..4a6d0e859e 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -19,7 +19,7 @@ class ToneConverter(LLMGenericTextConverter): """ Converts a conversation to a different tone using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py index 67922b4c03..636e50ad8d 100644 --- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py +++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py @@ -23,7 +23,7 @@ class ToxicSentenceGeneratorConverter(LLMGenericTextConverter): """ Generates toxic sentence starters using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). 
Based on Project Moonshot's attack module that generates toxic sentences to test LLM safety guardrails: diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index 22ece7f603..6fe9c910fd 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -3,11 +3,10 @@ from typing import Optional -from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration @@ -64,7 +63,6 @@ def __init__( custom_capabilities=custom_capabilities, ) - def is_response_format_json(self, message_piece: MessagePiece) -> bool: """ Check if the response format is JSON and ensure the target supports it. 
@@ -98,7 +96,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp """ config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) - if config.enabled and not self.capabilities.supports_json_output: + if config.enabled and not self.configuration.includes(capability=CapabilityName.JSON_OUTPUT): target_name = self.get_identifier().class_name raise ValueError(f"This target {target_name} does not support JSON response format.") diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index de8ef64b35..5c9712d21b 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import Message, MessagePiece -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration, resolve_configuration_compat logger = logging.getLogger(__name__) @@ -178,7 +178,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: custom_configuration_message = ( "If your target does support this, set the custom_configuration parameter accordingly." ) - if not self.capabilities.supports_multi_message_pieces and n_pieces != 1: + if not self.configuration.includes(capability=CapabilityName.MULTI_MESSAGE_PIECES) and n_pieces != 1: raise ValueError( f"This target only supports a single message piece. Received: {n_pieces} pieces. 
" f"{custom_configuration_message}" @@ -194,7 +194,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: f"{custom_configuration_message}" ) - if not self.capabilities.supports_multi_turn and len(normalized_conversation) > 1: + if not self.configuration.includes(capability=CapabilityName.MULTI_TURN) and len(normalized_conversation) > 1: raise ValueError(f"This target only supports a single turn conversation. {custom_configuration_message}") async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index eb20cebcb9..f2245c3bee 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -34,7 +34,7 @@ def __init__( Initialize the Insecure Code Scorer. Args: - chat_target (PromptChatTarget): The target to use for scoring code security. + chat_target (PromptTarget): The target to use for scoring code security. system_prompt_path (Optional[Union[str, Path]]): Path to the YAML file containing the system prompt. Defaults to the default insecure code scoring prompt if not provided. validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index 2b140033b9..3e019a01b4 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -188,7 +188,7 @@ def __init__( Initialize the SelfAskLikertScorer. Args: - chat_target (PromptChatTarget): The chat target to use for scoring. + chat_target (PromptTarget): The chat target to use for scoring. likert_scale (Optional[LikertScalePaths]): The Likert scale configuration to use for scoring. custom_likert_path (Optional[Path]): Path to a custom YAML file containing the Likert scale definition. 
This allows users to use their own Likert scales without modifying the code, as long as diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 5e7de81c69..636faf6bc0 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -53,7 +53,7 @@ def __init__( Initialize the SelfAskScaleScorer. Args: - chat_target (PromptChatTarget): The chat target to use for scoring. + chat_target (PromptTarget): The chat target to use for scoring. scale_arguments_path (Optional[Union[Path, str]]): Path to the YAML file containing scale definitions. Defaults to TREE_OF_ATTACKS_SCALE if not provided. system_prompt_path (Optional[Union[Path, str]]): Path to the YAML file containing the system prompt. diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 42f50f4e33..d2bc150791 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -523,7 +523,7 @@ async def _score_value_with_llm( description fields. Args: - prompt_target (PromptChatTarget): The target LLM to send the message to. + prompt_target (PromptTarget): The target LLM to send the message to. system_prompt (str): The system-level prompt that guides the behavior of the target LLM. message_value (str): The actual value or content to be scored by the LLM (e.g., text, image path, audio path). diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 043d0d6b19..c0aa35ecee 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -51,7 +51,7 @@ def __init__( Initialize a new instance of the SelfAskCategoryScorer class. Args: - chat_target (PromptChatTarget): The chat target to interact with. + chat_target (PromptTarget): The chat target to interact with. content_classifier_path (Union[str, Path]): The path to the classifier YAML file. 
score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index ecbb899fe6..4379dd76ed 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -79,7 +79,7 @@ def __init__( Initialize the SelfAskRefusalScorer. Args: - chat_target (PromptChatTarget): The endpoint that will be used to score the prompt. + chat_target (PromptTarget): The endpoint that will be used to score the prompt. refusal_system_prompt_path (Union[RefusalScorerPaths, Path, str]): The path to the system prompt to use for refusal detection. Can be a RefusalScorerPaths enum value, a Path, or a string path. Defaults to RefusalScorerPaths.OBJECTIVE_STRICT. diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index ffc357d5f5..6685037a7b 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -109,7 +109,7 @@ def __init__( Initialize the SelfAskTrueFalseScorer. Args: - chat_target (PromptChatTarget): The chat target to interact with. + chat_target (PromptTarget): The chat target to interact with. true_false_question_path (Optional[Union[str, Path]]): The path to the true/false question file. true_false_question (Optional[TrueFalseQuestion]): The true/false question object. true_false_system_prompt_path (Optional[Union[str, Path]]): The path to the system prompt file. 
diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index 11d019ed76..9de79ab384 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -27,6 +27,7 @@ def _make_strategy(*, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() with patch.multiple( @@ -376,6 +377,7 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() adversarial_chat = MagicMock() @@ -752,6 +754,7 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() adversarial_chat = MagicMock() From d2cb04b7fc0993c471b202186490728202323278 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 22 Apr 2026 20:59:39 -0400 Subject: [PATCH 8/8] correct chat definition --- .../attack/component/conversation_manager.py | 4 ++-- pyrit/executor/attack/multi_turn/crescendo.py | 12 ++++++------ .../attack/multi_turn/multi_turn_attack_strategy.py | 1 - pyrit/prompt_target/common/target_requirements.py | 2 +- pyrit/prompt_target/openai/openai_chat_target.py | 1 + 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index f1d86a4a35..b47bb06936 100644 --- 
a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -330,7 +330,7 @@ async def initialize_context_async( # prepended multi-message conversation as-is — route them to the # single-string fallback path. Type identity (PromptChatTarget) is a # legacy signal for this; capability-based routing is the durable form. - is_chat_target = target.configuration.includes(capability=CapabilityName.MULTI_TURN) + is_chat_target = target.configuration.includes(capability=CapabilityName.EDITABLE_HISTORY) if not is_chat_target: return await self._handle_non_chat_target_async( context=context, @@ -375,7 +375,7 @@ async def _handle_non_chat_target_async( if config.non_chat_target_behavior == "raise": raise ValueError( "prepended_conversation requires the objective target to be a chat-capable " - "PromptTarget.Non-chat objective targets do not support conversation history. " + "PromptTarget. Non-chat objective targets do not support conversation history. " "Use PrependedConversationConfig with non_chat_target_behavior='normalize_first_turn' " "to normalize the conversation into the first message instead." ) diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 3f0632e9f8..f2c44c257f 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -114,13 +114,13 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA You can learn more about the Crescendo attack [@russinovich2024crescendo]. """ - # Crescendo fundamentally relies on multi-turn conversation history to + # Crescendo fundamentally relies on editable conversation history to # gradually escalate prompts; history-squash adaptation would collapse the # conversation into a single prompt and silently break the attack's - # semantics. Declare MULTI_TURN as ``native_required`` so adaptation is + # semantics. 
Declare EDITABLE_HISTORY as ``required`` so unsupported targets are # rejected at construction time. TARGET_REQUIREMENTS = TargetRequirements( - native_required=frozenset({CapabilityName.MULTI_TURN}), + required=frozenset({CapabilityName.EDITABLE_HISTORY}), ) # Default system prompt template path for Crescendo attack @@ -145,8 +145,8 @@ def __init__( Initialize the Crescendo attack strategy. Args: - objective_target (PromptTarget): The target system to attack. Must natively - support multi-turn conversations. + objective_target (PromptTarget): The target system to attack. Must + support editable conversation history. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial component, including the adversarial chat target and optional system prompt path. attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters, @@ -160,7 +160,7 @@ def __init__( application by role, message normalization, and non-chat target behavior. Raises: - ValueError: If ``objective_target`` does not natively support multi-turn conversations. + ValueError: If ``objective_target`` does not support editable conversation history. 
""" # Initialize base class super().__init__(objective_target=objective_target, logger=logger, context_type=CrescendoAttackContext) diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 995ec696b0..007845d864 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -18,7 +18,6 @@ ) from pyrit.memory import CentralMemory from pyrit.models import ConversationReference, ConversationType - from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index 5ffee9ae9e..f0909dcc9d 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -64,5 +64,5 @@ def validate(self, *, target: PromptTarget) -> None: # and drive a short multi-turn conversation. Adaptation is acceptable: the # consumer only needs the behavior on the wire, not native support. CHAT_CONSUMER_REQUIREMENTS = TargetRequirements( - required=frozenset({CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN}), + required=frozenset({CapabilityName.EDITABLE_HISTORY, CapabilityName.MULTI_TURN}), ) diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index db059d6807..5e45136043 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -72,6 +72,7 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_editable_history=True, input_modalities=frozenset( {frozenset({"text"}), frozenset({"image_path"}), frozenset({"text", "image_path"})} ),