diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 7a27cb5666..b47bb06936 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -20,7 +20,7 @@ ) from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: from pyrit.executor.attack.core import AttackContext @@ -242,7 +242,7 @@ def get_last_message( def set_system_prompt( self, *, - target: PromptChatTarget, + target: PromptTarget, conversation_id: str, system_prompt: str, labels: Optional[dict[str, str]] = None, @@ -251,11 +251,16 @@ def set_system_prompt( Set or update the system prompt for a conversation. Args: - target: The chat target to set the system prompt on. + target: The target to set the system prompt on. Must handle the + ``SYSTEM_PROMPT`` capability (natively or via an ``ADAPT`` policy). conversation_id: Unique identifier for the conversation. system_prompt: The system prompt text. labels: Optional labels to associate with the system prompt. + + Raises: + ValueError: If ``target`` cannot handle the ``SYSTEM_PROMPT`` capability. """ + target.configuration.ensure_can_handle(capability=CapabilityName.SYSTEM_PROMPT) target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, @@ -283,7 +288,7 @@ async def initialize_context_async( 3. Updates context.executed_turns for multi-turn attacks 4. 
Sets context.next_message if there's an unanswered user message - For PromptChatTarget: + For chat-capable PromptTarget: - Adds prepended messages to memory with simulated_assistant role - All messages get new UUIDs @@ -306,7 +311,7 @@ async def initialize_context_async( Raises: ValueError: If conversation_id is empty, or if prepended_conversation - requires a PromptChatTarget but target is not one. + requires a chat-capable PromptTarget but target is not one. """ if not conversation_id: raise ValueError("conversation_id cannot be empty") @@ -321,8 +326,11 @@ async def initialize_context_async( logger.debug(f"No prepended conversation for context initialization: {conversation_id}") return state - # Handle target type compatibility - is_chat_target = isinstance(target, PromptChatTarget) + # Targets that don't natively support multi-turn history cannot consume a + # prepended multi-message conversation as-is — route them to the + # single-string fallback path. Type identity (PromptChatTarget) is a + # legacy signal for this; capability-based routing is the durable form. + is_chat_target = target.configuration.includes(capability=CapabilityName.EDITABLE_HISTORY) if not is_chat_target: return await self._handle_non_chat_target_async( context=context, @@ -366,8 +374,8 @@ async def _handle_non_chat_target_async( if config.non_chat_target_behavior == "raise": raise ValueError( - "prepended_conversation requires the objective target to be a PromptChatTarget. " - "Non-chat objective targets do not support conversation history. " + "prepended_conversation requires the objective target to be a chat-capable " + "PromptTarget. Non-chat objective targets do not support conversation history. " "Use PrependedConversationConfig with non_chat_target_behavior='normalize_first_turn' " "to normalize the conversation into the first message instead." 
) diff --git a/pyrit/executor/attack/component/prepended_conversation_config.py b/pyrit/executor/attack/component/prepended_conversation_config.py index c78ffad767..fddeae5371 100644 --- a/pyrit/executor/attack/component/prepended_conversation_config.py +++ b/pyrit/executor/attack/component/prepended_conversation_config.py @@ -22,7 +22,7 @@ class PrependedConversationConfig: This class provides control over: - Which message roles should have request converters applied - How to normalize conversation history for non-chat objective targets - - What to do when the objective target is not a PromptChatTarget + - What to do when the objective target is not a chat-capable PromptTarget """ # Roles for which request converters should be applied to prepended messages. @@ -36,13 +36,13 @@ class PrependedConversationConfig: # ConversationContextNormalizer is used that produces "Turn N: User/Assistant" format. message_normalizer: Optional[MessageStringNormalizer] = None - # Behavior when the target is a PromptTarget but not a PromptChatTarget: + # Behavior when the target is a PromptTarget but not a chat-capable PromptTarget: # - "normalize_first_turn": Normalize the prepended conversation into a string and # store it in ConversationState.normalized_prepended_context. This context will be # prepended to the first message sent to the target. Uses objective_target_context_normalizer # if provided, otherwise falls back to ConversationContextNormalizer. # - "raise": Raise a ValueError. Use this when prepended conversation history must be - # maintained by the target (i.e., target must be a PromptChatTarget). + # maintained by the target (i.e., target must be a chat-capable PromptTarget). 
non_chat_target_behavior: Literal["normalize_first_turn", "raise"] = "normalize_first_turn" def get_message_normalizer(self) -> MessageStringNormalizer: diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index dac86d10aa..efe2dcca8e 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -8,7 +8,7 @@ import time from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -27,6 +27,7 @@ ConversationReference, Message, ) +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig @@ -233,6 +234,10 @@ class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Id Defines the interface for executing attacks and handling results. """ + #: Capability requirements placed on ``objective_target``. Subclasses + #: override to declare what the attack needs. Validated in ``__init__``. 
+ TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + def __init__( self, *, @@ -259,6 +264,7 @@ def __init__( ), logger=logger, ) + type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._objective_target = objective_target self._params_type = params_type # Guard so subclasses that set converters before calling super() aren't clobbered diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 1a70c89195..a6bc17dcfe 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -29,6 +29,7 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: from pyrit.score import TrueFalseScorer @@ -141,6 +142,13 @@ def __init__( params_type=ChunkedRequestAttackParameters, ) + # Chunked request issues multiple distinct turns; history-squash + # adaptation would collapse them into a single prompt. + if not objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): + raise ValueError( + f"ChunkedRequestAttack requires a target that natively supports '{CapabilityName.MULTI_TURN.value}'." + ) + # Store chunk configuration self._chunk_size = chunk_size self._total_length = total_length @@ -226,16 +234,7 @@ async def _setup_async(self, *, context: ChunkedRequestAttackContext) -> None: Args: context (ChunkedRequestAttackContext): The attack context containing attack parameters. - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "ChunkedRequestAttack requires a multi-turn target. " - "The objective target does not support multi-turn conversations." 
- ) - # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 4a180d5df3..f2c44c257f 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -43,7 +43,9 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -112,6 +114,15 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA You can learn more about the Crescendo attack [@russinovich2024crescendo]. """ + # Crescendo fundamentally relies on editable conversation history to + # gradually escalate prompts; history-squash adaptation would collapse the + # conversation into a single prompt and silently break the attack's + # semantics. Declare EDITABLE_HISTORY as ``native_required`` so adaptation is + # rejected at construction time. 
+ TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.EDITABLE_HISTORY}), + ) + # Default system prompt template path for Crescendo attack DEFAULT_ADVERSARIAL_CHAT_SYSTEM_PROMPT_TEMPLATE_PATH: Path = ( Path(EXECUTOR_SEED_PROMPT_PATH) / "crescendo" / "crescendo_variant_1.yaml" @@ -121,7 +132,7 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA def __init__( self, *, - objective_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, @@ -134,7 +145,8 @@ def __init__( Initialize the Crescendo attack strategy. Args: - objective_target (PromptChatTarget): The target system to attack. Must be a PromptChatTarget. + objective_target (PromptTarget): The target system to attack. Must + support editable conversation history. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial component, including the adversarial chat target and optional system prompt path. attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters, @@ -148,7 +160,7 @@ def __init__( application by role, message normalization, and non-chat target behavior. Raises: - ValueError: If objective_target is not a PromptChatTarget. + ValueError: If ``objective_target`` does not natively support editable history. """ # Initialize base class super().__init__(objective_target=objective_target, logger=logger, context_type=CrescendoAttackContext) @@ -257,17 +269,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: Args: context (CrescendoAttackContext): Attack context with configuration - - Raises: - ValueError: If the objective target does not support multi-turn conversations.
""" - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "CrescendoAttack requires a multi-turn target. Crescendo fundamentally relies on " - "multi-turn conversation history to gradually escalate prompts. " - "Use RedTeamingAttack or TreeOfAttacksWithPruning instead." - ) - # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index a9d4b75adc..3239c5568e 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -29,6 +29,8 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import Scorer if TYPE_CHECKING: @@ -123,6 +125,15 @@ class MultiPromptSendingAttack(MultiTurnAttackStrategy[MultiTurnAttackContext[An and multiple scorer types for comprehensive evaluation. """ + # Sending a sequence of distinct prompts depends on the target maintaining + # conversation state between them. History-squash adaptation would collapse + # them into one message and silently break the attack's sequencing + # semantics. Declare MULTI_TURN as ``native_required`` so adaptation is + # rejected at construction time. + TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.MULTI_TURN}), + ) + @apply_defaults def __init__( self, @@ -204,16 +215,7 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: Args: context (MultiTurnAttackContext): The attack context containing attack parameters. - - Raises: - ValueError: If the objective target does not support multi-turn conversations. 
""" - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "MultiPromptSendingAttack requires a multi-turn target. " - "The objective target does not support multi-turn conversations." - ) - # Ensure the context has a session (like red_teaming.py does) context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 04c8084f7b..007845d864 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -18,6 +18,7 @@ ) from pyrit.memory import CentralMemory from pyrit.models import ConversationReference, ConversationType +from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: from pyrit.models import ( @@ -117,7 +118,7 @@ def _rotate_conversation_for_single_turn_target( Args: context: The current attack context. """ - if self._objective_target.capabilities.supports_multi_turn: + if self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): return if context.executed_turns == 0: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 7ea7f927b7..cbd604ccb1 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -50,7 +50,8 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptChatTarget, PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -257,7 +258,7 @@ class _TreeOfAttacksNode: def __init__( self, *, - objective_target: PromptChatTarget, + objective_target: PromptTarget, adversarial_chat: 
PromptChatTarget, adversarial_chat_seed_prompt: SeedPrompt, adversarial_chat_prompt_template: SeedPrompt, @@ -279,7 +280,7 @@ def __init__( Initialize a tree node. Args: - objective_target (PromptChatTarget): The target to attack. + objective_target (PromptTarget): The target to attack. adversarial_chat (PromptChatTarget): The chat target for generating adversarial prompts. adversarial_chat_seed_prompt (SeedPrompt): The seed prompt for the first turn. adversarial_chat_prompt_template (SeedPrompt): The template for subsequent turns. @@ -780,7 +781,7 @@ def duplicate(self) -> "_TreeOfAttacksNode": # For single-turn targets, duplicate only the system messages (e.g., system prompt # from prepended conversation) so the target retains its configuration without # carrying over attack turn history that would cause validation errors. - if self._objective_target.capabilities.supports_multi_turn: + if self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation( conversation_id=self.objective_target_conversation_id ) @@ -1254,7 +1255,7 @@ class TreeOfAttacksWithPruningAttack(AttackStrategy[TAPAttackContext, TAPAttackR def __init__( self, *, - objective_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, @@ -1271,7 +1272,7 @@ def __init__( Initialize the Tree of Attacks with Pruning attack strategy. Args: - objective_target (PromptChatTarget): The target system to attack. + objective_target (PromptTarget): The target system to attack. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial chat component. 
attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters. Defaults to None. @@ -1293,7 +1294,8 @@ def __init__( Raises: ValueError: If attack_scoring_config uses a non-FloatScaleThresholdScorer objective scorer, - if target is not PromptChatTarget, or if parameters are invalid. + if the adversarial target does not natively support the capabilities TAP needs, + or if parameters are invalid. """ # Validate tree parameters if tree_depth < 1: @@ -1322,8 +1324,19 @@ def __init__( # Initialize adversarial configuration self._adversarial_chat = attack_adversarial_config.target - if not isinstance(self._adversarial_chat, PromptChatTarget): - raise ValueError("The adversarial target must be a PromptChatTarget for TAP attack.") + # TAP sets a system prompt on the adversarial target and drives a + # multi-turn dialogue through it; both capabilities must be native. + adversarial_config = self._adversarial_chat.configuration + missing_native = [ + capability.value + for capability in (CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT) + if not adversarial_config.includes(capability=capability) + ] + if missing_native: + raise ValueError( + "TreeOfAttacksWithPruningAttack requires an adversarial target that natively supports: " + + ", ".join(missing_native) + ) # Load system prompts self._adversarial_chat_system_prompt_path = ( @@ -1857,7 +1870,7 @@ def _create_attack_node( generate adversarial prompts and evaluate responses. 
""" node = _TreeOfAttacksNode( - objective_target=cast("PromptChatTarget", self._objective_target), + objective_target=self._objective_target, adversarial_chat=self._adversarial_chat, adversarial_chat_seed_prompt=self._adversarial_chat_seed_prompt, adversarial_chat_system_seed_prompt=self._adversarial_chat_system_seed_prompt, diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index 46f427caef..7cbc6f5e9e 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,14 +19,14 @@ class DenylistConverter(LLMGenericTextConverter): """ Replaces forbidden words or phrases in a prompt with synonyms using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, denylist: list[str] | None = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with a target, an optional system prompt template, and a denylist. Args: - converter_target (PromptChatTarget): The target for the prompt conversion. + converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. 
system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. If not provided, a default template will be used. diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index f56990247f..59808827d7 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -15,7 +15,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -27,12 +27,13 @@ class LLMGenericTextConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, user_prompt_template_with_objective: Optional[SeedPrompt] = None, **kwargs: Any, @@ -41,8 +42,10 @@ def __init__( Initialize the converter with a target and optional prompt templates. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. - Can be omitted if a default has been configured via PyRIT initialization. + converter_target (PromptTarget): The endpoint that converts the prompt. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, possibly via + normalization-pipeline adaptation). Can be omitted if a default has been configured + via PyRIT initialization. system_prompt_template (SeedPrompt, Optional): The prompt template to set as the system prompt. user_prompt_template_with_objective (SeedPrompt, Optional): The prompt template to set as the user prompt. 
expects @@ -51,6 +54,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. """ + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self._converter_target = converter_target self._system_prompt_template = system_prompt_template self._prompt_kwargs = kwargs diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py index 41a7848458..5725fff9c4 100644 --- a/pyrit/prompt_converter/malicious_question_generator_converter.py +++ b/pyrit/prompt_converter/malicious_question_generator_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,21 +19,21 @@ class MaliciousQuestionGeneratorConverter(LLMGenericTextConverter): """ Generates malicious questions using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. 
""" diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py index fd6491bbc1..a3520b190c 100644 --- a/pyrit/prompt_converter/math_prompt_converter.py +++ b/pyrit/prompt_converter/math_prompt_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,21 +19,21 @@ class MathPromptConverter(LLMGenericTextConverter): """ Converts natural language instructions into symbolic mathematics problems using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. 
""" diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index 0d7bdf302f..86c5375773 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -11,7 +11,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -20,14 +20,14 @@ class NoiseConverter(LLMGenericTextConverter): """ Injects noise errors into a conversation using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] noise: Optional[str] = None, number_errors: int = 5, prompt_template: Optional[SeedPrompt] = None, @@ -36,7 +36,7 @@ def __init__( Initialize the converter with the specified parameters. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. noise (str): The noise to inject. Grammar error, delete random letter, insert random space, etc. number_errors (int): The number of errors to inject. 
diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index 11b6bd66e6..363405c8bd 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -21,7 +21,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -47,19 +47,20 @@ class PersuasionConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] persuasion_technique: str, ): """ Initialize the converter with the specified target and prompt template. Args: - converter_target (PromptChatTarget): The chat target used to perform rewriting on user prompts. + converter_target (PromptTarget): The chat target used to perform rewriting on user prompts. Can be omitted if a default has been configured via PyRIT initialization. persuasion_technique (str): Persuasion technique to be used by the converter, determines the system prompt to be used to generate new prompts. Must be one of "authority_endorsement", "evidence_based", @@ -69,6 +70,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the persuasion technique is not supported or does not exist. 
""" + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self.converter_target = converter_target try: diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 141076e701..b245a0c846 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,11 +6,12 @@ import inspect import re from dataclasses import dataclass -from typing import Any, Optional, Union, get_args +from typing import Any, ClassVar, Optional, Union, get_args from pyrit import prompt_converter from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.models import PromptDataType +from pyrit.prompt_target.common.target_requirements import TargetRequirements @dataclass @@ -48,6 +49,11 @@ class PromptConverter(Identifiable): #: Tuple of output modalities supported by this converter. Subclasses must override this. SUPPORTED_OUTPUT_TYPES: tuple[PromptDataType, ...] = () + #: Capability requirements placed on the converter's target (if any). + #: Subclasses that use a target should override this and call + #: ``type(self).TARGET_REQUIREMENTS.validate(target=converter_target)`` in ``__init__``. 
+ TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + _identifier: Optional[ComponentIdentifier] = None def __init_subclass__(cls, **kwargs: object) -> None: diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 74953c2603..7e11810323 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -13,7 +13,7 @@ from pyrit.prompt_converter.prompt_converter import ConverterResult from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy from pyrit.prompt_converter.word_level_converter import WordLevelConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): """ Translates each individual word in a prompt to a random language using an LLM. - An existing ``PromptChatTarget`` is used to perform the translation (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the translation (like Azure OpenAI). """ SUPPORTED_INPUT_TYPES = ("text",) @@ -35,7 +35,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, languages: Optional[list[str]] = None, word_selection_strategy: Optional[WordSelectionStrategy] = None, @@ -44,7 +44,7 @@ def __init__( Initialize the converter with a target, an optional system prompt template, and language options. Args: - converter_target (PromptChatTarget): The target for the prompt conversion. + converter_target (PromptTarget): The target for the prompt conversion. 
Can be omitted if a default has been configured via PyRIT initialization. system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. If not provided, a default template will be used. diff --git a/pyrit/prompt_converter/scientific_translation_converter.py b/pyrit/prompt_converter/scientific_translation_converter.py index 2a6c965996..bdc7987041 100644 --- a/pyrit/prompt_converter/scientific_translation_converter.py +++ b/pyrit/prompt_converter/scientific_translation_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ class ScientificTranslationConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] mode: str = "combined", prompt_template: Optional[SeedPrompt] = None, ) -> None: @@ -53,7 +53,7 @@ def __init__( Initialize the scientific translation converter. Args: - converter_target (PromptChatTarget): The LLM target to perform the conversion. + converter_target (PromptTarget): The LLM target to perform the conversion. mode (str): The translation mode to use. 
Built-in options are: - ``academic``: Use academic/homework style framing diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index 237a2934d5..eede7adef9 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,14 +19,14 @@ class TenseConverter(LLMGenericTextConverter): """ Converts a conversation to a different tense using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] tense: str, prompt_template: Optional[SeedPrompt] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with the target chat support, tense, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. tense (str): The tense the converter should convert the prompt to. E.g. past, present, future. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. 
diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index a7b8e5a9f1..4a6d0e859e 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,14 +19,14 @@ class ToneConverter(LLMGenericTextConverter): """ Converts a conversation to a different tone using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] tone: str, prompt_template: Optional[SeedPrompt] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with the target chat support, tone, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. tone (str): The tone for the conversation. E.g. upset, sarcastic, indifferent, etc. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. 
diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py index d3390c6af7..636e50ad8d 100644 --- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py +++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py @@ -14,7 +14,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class ToxicSentenceGeneratorConverter(LLMGenericTextConverter): """ Generates toxic sentence starters using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). Based on Project Moonshot's attack module that generates toxic sentences to test LLM safety guardrails: @@ -34,14 +34,14 @@ class ToxicSentenceGeneratorConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. If not provided, defaults to the ``toxic_sentence_generator.yaml``. 
diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index 911f72ab57..0228a0265c 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -24,7 +24,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -36,12 +36,13 @@ class TranslationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] language: str, prompt_template: Optional[SeedPrompt] = None, max_retries: int = 3, @@ -51,7 +52,7 @@ def __init__( Initialize the converter with the target chat support, language, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. language (str): The language for the conversion. E.g. Spanish, French, leetspeak, etc. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. @@ -62,6 +63,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the language is not provided. 
""" + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self.converter_target = converter_target # Retry strategy for the conversion diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 328e463072..e7d35e76ee 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -23,7 +23,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -35,19 +35,20 @@ class VariationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with the specified target and prompt template. Args: - converter_target (PromptChatTarget): The target to which the prompt will be sent for conversion. + converter_target (PromptTarget): The target to which the prompt will be sent for conversion. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt, optional): The template used for generating the system prompt. If not provided, a default template will be used. @@ -55,6 +56,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) self.converter_target = converter_target # set to default strategy if not provided diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index c71dca4089..db24087d22 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -20,7 +20,10 @@ UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.target_requirements import TargetRequirements +from pyrit.prompt_target.common.target_requirements import ( + CHAT_CONSUMER_REQUIREMENTS, + TargetRequirements, +) from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget from pyrit.prompt_target.http_target.http_target import HTTPTarget @@ -51,6 +54,7 @@ "AzureMLChatTarget", "CapabilityName", "CapabilityHandlingPolicy", + "CHAT_CONSUMER_REQUIREMENTS", "CopilotType", "ConversationNormalizationPipeline", "GandalfLevel", diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index ce1f254678..6fe9c910fd 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -3,11 +3,10 @@ from typing import Optional -from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration @@ -64,37 +63,6 @@ def __init__( custom_capabilities=custom_capabilities, ) - def set_system_prompt( - self, - *, - system_prompt: str, - 
conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, - ) -> None: - """ - Set the system prompt for the prompt target. May be overridden by subclasses. - - Raises: - RuntimeError: If the conversation already exists. - """ - messages = self._memory.get_conversation(conversation_id=conversation_id) - - if messages: - raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") - - self._memory.add_message_to_memory( - request=MessagePiece( - role="system", - conversation_id=conversation_id, - original_value=system_prompt, - converted_value=system_prompt, - prompt_target_identifier=self.get_identifier(), - attack_identifier=attack_identifier, - labels=labels, - ).to_message() - ) - def is_response_format_json(self, message_piece: MessagePiece) -> bool: """ Check if the response format is JSON and ensure the target supports it. @@ -128,7 +96,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp """ config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) - if config.enabled and not self.capabilities.supports_json_output: + if config.enabled and not self.configuration.includes(capability=CapabilityName.JSON_OUTPUT): target_name = self.get_identifier().class_name raise ValueError(f"This target {target_name} does not support JSON response format.") diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 53d7d2085a..5c9712d21b 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -8,8 +8,8 @@ from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import Message -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.models import Message, MessagePiece +from 
pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration, resolve_configuration_compat logger = logging.getLogger(__name__) @@ -178,7 +178,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: custom_configuration_message = ( "If your target does support this, set the custom_configuration parameter accordingly." ) - if not self.capabilities.supports_multi_message_pieces and n_pieces != 1: + if not self.configuration.includes(capability=CapabilityName.MULTI_MESSAGE_PIECES) and n_pieces != 1: raise ValueError( f"This target only supports a single message piece. Received: {n_pieces} pieces. " f"{custom_configuration_message}" @@ -194,7 +194,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: f"{custom_configuration_message}" ) - if not self.capabilities.supports_multi_turn and len(normalized_conversation) > 1: + if not self.configuration.includes(capability=CapabilityName.MULTI_TURN) and len(normalized_conversation) > 1: raise ValueError(f"This target only supports a single turn conversation. {custom_configuration_message}") async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: @@ -270,6 +270,56 @@ def set_model_name(self, *, model_name: str) -> None: """ self._model_name = model_name + def set_system_prompt( + self, + *, + system_prompt: str, + conversation_id: str, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, + ) -> None: + """ + Inject a system prompt into memory for the given conversation. + + Writes a ``system``-role message so the target's normalization pipeline + (or the target itself, when it natively supports system prompts) will + pick it up on the next ``send_prompt_async`` call. 
+ + If the target does not natively support system prompts, whether this + call is ultimately honored depends on the target's + :class:`CapabilityHandlingPolicy`: + + * ``ADAPT`` — the normalization pipeline (e.g. system squash) will + fold the system message into user content on the wire. + * ``RAISE`` — the first send after the system prompt is set will + raise, because the pipeline cannot adapt the missing capability. + + Args: + system_prompt (str): The system prompt text to set. + conversation_id (str): The conversation id to attach the prompt to. + attack_identifier (ComponentIdentifier | None): Optional attack identifier. + labels (dict[str, str] | None): Optional labels. + + Raises: + RuntimeError: If the conversation already has messages. + """ + messages = self._memory.get_conversation(conversation_id=conversation_id) + + if messages: + raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") + + self._memory.add_message_to_memory( + request=MessagePiece( + role="system", + conversation_id=conversation_id, + original_value=system_prompt, + converted_value=system_prompt, + prompt_target_identifier=self.get_identifier(), + attack_identifier=attack_identifier, + labels=labels, + ).to_message() + ) + def dispose_db_engine(self) -> None: """ Dispose database engine to release database connections and resources. 
diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index 95182b47b5..f0909dcc9d 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -6,9 +6,10 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from pyrit.prompt_target.common.target_capabilities import CapabilityName + if TYPE_CHECKING: - from pyrit.prompt_target.common.target_capabilities import CapabilityName - from pyrit.prompt_target.common.target_configuration import TargetConfiguration + from pyrit.prompt_target.common.prompt_target import PromptTarget @dataclass(frozen=True) @@ -17,38 +18,51 @@ class TargetRequirements: Declarative description of what a consumer (attack, converter, scorer) requires from a target. - Consumers define their requirements once and validate them against a - ``TargetConfiguration`` at construction time. This replaces ad-hoc - ``isinstance`` checks and scattered capability branching. + The single source of truth for capability names is the + :class:`CapabilityName` enum; this class is simply a typed wrapper + around the set of capabilities a consumer needs. + + Two tiers of requirement are supported: + + * ``required`` \u2014 satisfied either by native support on the target or + by an ``ADAPT`` entry in the target's + :class:`CapabilityHandlingPolicy`. Use this when the consumer only + needs the behavior to appear on the wire. + * ``native_required`` \u2014 must be natively supported. Adaptation is + rejected. Use this when adaptation would silently change the + consumer's semantics (e.g. an attack that depends on the target + remembering prior turns, where history-squash normalization would + collapse the conversation into a single prompt). """ - # The set of capabilities the consumer requires. 
- required_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) + required: frozenset[CapabilityName] = field(default_factory=frozenset) + native_required: frozenset[CapabilityName] = field(default_factory=frozenset) - def validate(self, *, configuration: TargetConfiguration) -> None: + def validate(self, *, target: PromptTarget) -> None: """ - Validate that the target configuration can satisfy all requirements. - - Iterates over every required capability and delegates to - ``TargetConfiguration.ensure_can_handle``, which checks native support - first and then consults the handling policy. All violations are - collected and reported in a single ``ValueError``. + Validate that ``target`` can satisfy every declared requirement. Args: - configuration (TargetConfiguration): The target configuration to validate against. + target (PromptTarget): The target to validate against. Raises: - ValueError: If any required capability is missing and the policy - does not allow adaptation. + ValueError: If any ``native_required`` capability is not natively + supported, or if any ``required`` capability is not supported + natively and has no ``ADAPT`` entry in the target's policy. """ - errors: list[str] = [] - for capability in sorted(self.required_capabilities, key=lambda c: c.value): - try: - configuration.ensure_can_handle(capability=capability) - except ValueError as exc: - errors.append(str(exc)) - if errors: - raise ValueError( - f"Target does not satisfy {len(errors)} required capability(ies):\n" - + "\n".join(f" - {e}" for e in errors) - ) + for capability in self.native_required: + if not target.configuration.includes(capability=capability): + raise ValueError( + f"Target must natively support '{capability.value}'; " + "adaptation is not acceptable for this consumer." 
+ ) + for capability in self.required: + target.configuration.ensure_can_handle(capability=capability) + + +# Shared requirement used by scorers and converters that set a system prompt +# and drive a short multi-turn conversation. Adaptation is acceptable: the +# consumer only needs the behavior on the wire, not native support. +CHAT_CONSUMER_REQUIREMENTS = TargetRequirements( + required=frozenset({CapabilityName.EDITABLE_HISTORY, CapabilityName.MULTI_TURN}), +) diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index db059d6807..5e45136043 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -72,6 +72,7 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_editable_history=True, input_modalities=frozenset( {frozenset({"text"}), frozenset({"image_path"}), frozenset({"text", "image_path"})} ), diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index b8963b9264..d4eba0303f 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -28,6 +28,7 @@ PromptConverterConfiguration, ) from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -421,9 +422,10 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: if self._objective_target is None: raise ValueError("objective_target must be set before creating attacks") - if not isinstance(self._objective_target, 
PromptChatTarget): + if not self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): raise TypeError( - f"PsychosocialHarmsScenario requires a PromptChatTarget, got {type(self._objective_target).__name__}" + f"PsychosocialHarmsScenario requires a target that natively supports " + f"multi-turn conversations, got {type(self._objective_target).__name__}." ) resolved = self._resolve_seed_groups() self._seed_groups = resolved.seed_groups diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index a117034b3b..e97f547a71 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -7,7 +7,7 @@ from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType, Score, UnvalidatedScore -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -70,7 +70,7 @@ def get_scorer_metrics(self) -> Optional["HarmScorerMetrics"]: async def _score_value_with_llm( self, *, - prompt_target: PromptChatTarget, + prompt_target: PromptTarget, system_prompt: str, message_value: str, message_data_type: PromptDataType, diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 45c64dab00..f2245c3bee 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -9,7 +9,7 @@ from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget 
from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -21,11 +21,12 @@ class InsecureCodeScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_path: Optional[Union[str, Path]] = None, validator: Optional[ScorerPromptValidator] = None, ): @@ -33,13 +34,14 @@ def __init__( Initialize the Insecure Code Scorer. Args: - chat_target (PromptChatTarget): The target to use for scoring code security. + chat_target (PromptTarget): The target to use for scoring code security. system_prompt_path (Optional[Union[str, Path]]): Path to the YAML file containing the system prompt. Defaults to the default insecure code scoring prompt if not provided. validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. 
""" super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index ae9e0acc4b..9750a936a3 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -5,13 +5,14 @@ from typing import TYPE_CHECKING, Optional +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator if TYPE_CHECKING: from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): @@ -24,11 +25,12 @@ class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): supported_data_types=["text"], is_objective_required=True, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_format_string: str, prompt_format_string: Optional[str] = None, category: Optional[str] = None, @@ -52,7 +54,9 @@ def __init__( in the response, the provided `category` argument will be applied. Args: - chat_target (PromptChatTarget): The chat target used to score. + chat_target (PromptTarget): The chat target used to score. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, + possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, prompt, and message_piece. 
prompt_format_string (Optional[str]): User prompt template with the same placeholders. @@ -72,6 +76,7 @@ def __init__( ValueError: If min_value is greater than max_value. """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_format_string: raise ValueError("system_prompt_format_string must be provided and non-empty.") diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index c6762089b6..3e019a01b4 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -12,7 +12,7 @@ from pyrit.common.path import HARM_DEFINITION_PATH, SCORER_LIKERT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -173,11 +173,12 @@ class SelfAskLikertScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, likert_scale: Optional[LikertScalePaths] = None, custom_likert_path: Optional[Path] = None, custom_system_prompt_path: Optional[Path] = None, @@ -187,7 +188,7 @@ def __init__( Initialize the SelfAskLikertScorer. Args: - chat_target (PromptChatTarget): The chat target to use for scoring. + chat_target (PromptTarget): The chat target to use for scoring. likert_scale (Optional[LikertScalePaths]): The Likert scale configuration to use for scoring. 
custom_likert_path (Optional[Path]): Path to a custom YAML file containing the Likert scale definition. This allows users to use their own Likert scales without modifying the code, as long as @@ -203,6 +204,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target self._likert_scale = likert_scale diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 4bf0dc2dee..636faf6bc0 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -11,7 +11,7 @@ from pyrit.common.path import SCORER_SCALES_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -39,11 +39,12 @@ class SystemPaths(enum.Enum): supported_data_types=["text"], is_objective_required=True, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, scale_arguments_path: Optional[Union[Path, str]] = None, system_prompt_path: Optional[Union[Path, str]] = None, validator: Optional[ScorerPromptValidator] = None, @@ -52,7 +53,7 @@ def __init__( Initialize the SelfAskScaleScorer. Args: - chat_target (PromptChatTarget): The chat target to use for scoring. + chat_target (PromptTarget): The chat target to use for scoring. scale_arguments_path (Optional[Union[Path, str]]): Path to the YAML file containing scale definitions. Defaults to TREE_OF_ATTACKS_SCALE if not provided. 
system_prompt_path (Optional[Union[Path, str]]): Path to the YAML file containing the system prompt. @@ -61,6 +62,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_path: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 11308edb64..d2bc150791 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -12,6 +12,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Optional, Union, cast, @@ -35,11 +36,12 @@ UnvalidatedScore, ) from pyrit.prompt_target.batch_helper import batch_task_async +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from collections.abc import Sequence - from pyrit.prompt_target import PromptChatTarget, PromptTarget + from pyrit.prompt_target import PromptTarget from pyrit.score.scorer_evaluation.metrics_type import RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_evaluator import ( ScorerEvalDatasetFiles, @@ -59,6 +61,11 @@ class Scorer(Identifiable, abc.ABC): # Specifies glob patterns for datasets and a result file name. evaluation_file_mapping: Optional[ScorerEvalDatasetFiles] = None + #: Capability requirements placed on the scorer's chat target (if any). + #: Subclasses that use a chat target should override this and call + #: ``type(self).TARGET_REQUIREMENTS.validate(target=chat_target)`` in ``__init__``. 
+ TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + _identifier: Optional[ComponentIdentifier] = None def __init__(self, *, validator: ScorerPromptValidator): @@ -494,7 +501,7 @@ def scale_value_float(self, value: float, min_value: float, max_value: float) -> async def _score_value_with_llm( self, *, - prompt_target: PromptChatTarget, + prompt_target: PromptTarget, system_prompt: str, message_value: str, message_data_type: PromptDataType, @@ -516,7 +523,7 @@ async def _score_value_with_llm( description fields. Args: - prompt_target (PromptChatTarget): The target LLM to send the message to. + prompt_target (PromptTarget): The target LLM to send the message to. system_prompt (str): The system-level prompt that guides the behavior of the target LLM. message_value (str): The actual value or content to be scored by the LLM (e.g., text, image path, audio path). diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 2aab7c264e..a3b838f190 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -11,7 +11,7 @@ from pyrit.exceptions import PyritException, pyrit_target_retry from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, MessagePiece, Score -from pyrit.prompt_target import GandalfLevel, PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, GandalfLevel, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -30,12 +30,13 @@ class GandalfScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, level: GandalfLevel, - chat_target: PromptChatTarget, + chat_target: PromptTarget, validator: Optional[ScorerPromptValidator] = None, 
score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: @@ -44,13 +45,14 @@ def __init__( Args: level (GandalfLevel): The Gandalf challenge level to score against. - chat_target (PromptChatTarget): The chat target used for password extraction. + chat_target (PromptTarget): The chat target used for password extraction. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to text data type validator. score_aggregator (TrueFalseAggregatorFunc): Aggregator for combining scores. Defaults to TrueFalseScoreAggregator.OR. """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target self._defender = level.value self._endpoint = "https://gandalf-api.lakera.ai/api/guess-password" diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 7102ba3af6..c0aa35ecee 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -11,7 +11,7 @@ from pyrit.common.path import SCORER_CONTENT_CLASSIFIERS_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -37,11 +37,12 @@ class SelfAskCategoryScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, content_classifier_path: Union[str, Path], score_aggregator: TrueFalseAggregatorFunc = 
TrueFalseScoreAggregator.OR, validator: Optional[ScorerPromptValidator] = None, @@ -50,7 +51,7 @@ def __init__( Initialize a new instance of the SelfAskCategoryScorer class. Args: - chat_target (PromptChatTarget): The chat target to interact with. + chat_target (PromptTarget): The chat target to interact with. content_classifier_path (Union[str, Path]): The path to the classifier YAML file. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. @@ -58,6 +59,7 @@ def __init__( """ super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target content_classifier_path = verify_and_resolve_path(content_classifier_path) diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 44bb362748..ffd34c4dd2 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -15,7 +16,7 @@ if TYPE_CHECKING: from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): @@ -28,11 +29,12 @@ class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): supported_data_types=["text"], is_objective_required=False, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, 
system_prompt_format_string: str, prompt_format_string: Optional[str] = None, category: Optional[str] = None, @@ -55,7 +57,9 @@ def __init__( in the response, the provided `category` argument will be applied. Args: - chat_target (PromptChatTarget): The chat target used to score. + chat_target (PromptTarget): The chat target used to score. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, + possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, task (alias of objective), prompt, and message_piece. prompt_format_string (Optional[str]): User prompt template with the same placeholders. @@ -74,6 +78,7 @@ def __init__( ValueError: If system_prompt_format_string is not provided or empty. """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if not system_prompt_format_string: raise ValueError("system_prompt_format_string must be provided and non-empty.") diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index bf1c017dde..b52ed9fa79 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -18,7 +18,7 @@ import pathlib from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): @@ -37,7 +37,7 @@ class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, true_false_question_path: Optional[pathlib.Path] = None, validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = 
TrueFalseScoreAggregator.OR, @@ -46,7 +46,9 @@ def __init__( Initialize the SelfAskQuestionAnswerScorer object. Args: - chat_target (PromptChatTarget): The chat target to use for the scorer. + chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (system-prompt + multi-turn capabilities, + possibly via normalization-pipeline adaptation). true_false_question_path (Optional[pathlib.Path]): The path to the true/false question file. Defaults to None, which uses the default question_answering.yaml file. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index cf9b30f1d8..4379dd76ed 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -8,7 +8,7 @@ from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -64,11 +64,12 @@ class SelfAskRefusalScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, refusal_system_prompt_path: Union[RefusalScorerPaths, Path, str] = RefusalScorerPaths.OBJECTIVE_STRICT, prompt_format_string: Optional[str] = None, validator: Optional[ScorerPromptValidator] = None, @@ -78,7 +79,7 @@ def __init__( Initialize the SelfAskRefusalScorer. 
Args: - chat_target (PromptChatTarget): The endpoint that will be used to score the prompt. + chat_target (PromptTarget): The endpoint that will be used to score the prompt. refusal_system_prompt_path (Union[RefusalScorerPaths, Path, str]): The path to the system prompt to use for refusal detection. Can be a RefusalScorerPaths enum value, a Path, or a string path. Defaults to RefusalScorerPaths.OBJECTIVE_STRICT. @@ -102,6 +103,7 @@ def __init__( super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target # Resolve the system prompt path diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index d79060fcb4..6685037a7b 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -12,7 +12,7 @@ from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -93,11 +93,12 @@ class SelfAskTrueFalseScorer(TrueFalseScorer): _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator( supported_data_types=["text", "image_path"], ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, true_false_question_path: Optional[Union[str, Path]] = None, true_false_question: Optional[TrueFalseQuestion] = None, true_false_system_prompt_path: Optional[Union[str, Path]] = None, @@ -108,7 +109,7 @@ def __init__( Initialize the SelfAskTrueFalseScorer. 
Args: - chat_target (PromptChatTarget): The chat target to interact with. + chat_target (PromptTarget): The chat target to interact with. true_false_question_path (Optional[Union[str, Path]]): The path to the true/false question file. true_false_question (Optional[TrueFalseQuestion]): The true/false question object. true_false_system_prompt_path (Optional[Union[str, Path]]): The path to the system prompt file. @@ -122,6 +123,7 @@ def __init__( """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) self._prompt_target = chat_target if true_false_question_path and true_false_question: diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index c86e741e9c..a87948c330 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -94,7 +94,13 @@ def mock_chat_target() -> MagicMock: @pytest.fixture def mock_prompt_target() -> MagicMock: """Create a mock prompt target (non-chat) for testing.""" + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + target = MagicMock(spec=PromptTarget) + target.configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False), + ) target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index 3baeaf463c..9de79ab384 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -27,6 +27,7 @@ 
def _make_strategy(*, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() with patch.multiple( @@ -376,6 +377,7 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() adversarial_chat = MagicMock() @@ -684,8 +686,14 @@ class TestValueErrorGuards: """Test that incompatible attacks raise ValueError for single-turn targets.""" def _make_single_turn_target(self): + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + target = MagicMock() target.capabilities.supports_multi_turn = False + target.configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True), + ) target.get_identifier.return_value = MagicMock() return target @@ -706,52 +714,34 @@ def _make_scoring_config(self): @pytest.mark.asyncio async def test_crescendo_raises_for_single_turn_target(self): - from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack, CrescendoAttackContext + from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack target = self._make_single_turn_target() - attack = CrescendoAttack( - objective_target=target, - attack_adversarial_config=self._make_adversarial_config(), - attack_scoring_config=self._make_scoring_config(), - ) - - context = CrescendoAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="CrescendoAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + CrescendoAttack( 
+ objective_target=target, + attack_adversarial_config=self._make_adversarial_config(), + attack_scoring_config=self._make_scoring_config(), + ) @pytest.mark.asyncio async def test_multi_prompt_sending_raises_for_single_turn_target(self): from pyrit.executor.attack.multi_turn.multi_prompt_sending import MultiPromptSendingAttack target = self._make_single_turn_target() - attack = MultiPromptSendingAttack(objective_target=target) - - context = MultiTurnAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="MultiPromptSendingAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + MultiPromptSendingAttack(objective_target=target) @pytest.mark.asyncio async def test_chunked_request_raises_for_single_turn_target(self): - from pyrit.executor.attack.multi_turn.chunked_request import ( - ChunkedRequestAttack, - ChunkedRequestAttackContext, - ) + from pyrit.executor.attack.multi_turn.chunked_request import ChunkedRequestAttack target = self._make_single_turn_target() - attack = ChunkedRequestAttack(objective_target=target) - - context = ChunkedRequestAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="ChunkedRequestAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + ChunkedRequestAttack(objective_target=target) @pytest.mark.usefixtures("patch_central_database") @@ -764,6 +754,7 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() adversarial_chat = MagicMock() diff --git a/tests/unit/prompt_target/target/test_target_requirements.py 
b/tests/unit/prompt_target/target/test_target_requirements.py index 002ccf086c..0774afdda2 100644 --- a/tests/unit/prompt_target/target/test_target_requirements.py +++ b/tests/unit/prompt_target/target/test_target_requirements.py @@ -1,131 +1,116 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from unittest.mock import MagicMock + import pytest +from pyrit.prompt_target import ( + CHAT_CONSUMER_REQUIREMENTS, + CapabilityName, + TargetRequirements, +) from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, - CapabilityName, TargetCapabilities, UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.target_requirements import TargetRequirements - - -@pytest.fixture -def adapt_all_policy(): - return CapabilityHandlingPolicy( - behaviors={ - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, - } - ) -# --------------------------------------------------------------------------- -# Construction -# --------------------------------------------------------------------------- +def _make_target(*, configuration: TargetConfiguration) -> MagicMock: + target = MagicMock() + target.configuration = configuration + return target -def test_init_default_has_empty_capabilities(): - reqs = TargetRequirements() - assert reqs.required_capabilities == frozenset() +def test_default_requirements_require_nothing(): + assert TargetRequirements().required == frozenset() -def test_init_with_capabilities(): +def test_construction_from_frozenset(): reqs = TargetRequirements( - 
required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) + required=frozenset({CapabilityName.MULTI_TURN, CapabilityName.JSON_OUTPUT}), ) - assert CapabilityName.MULTI_TURN in reqs.required_capabilities - assert CapabilityName.SYSTEM_PROMPT in reqs.required_capabilities + assert reqs.required == {CapabilityName.MULTI_TURN, CapabilityName.JSON_OUTPUT} -# --------------------------------------------------------------------------- -# validate — all pass -# --------------------------------------------------------------------------- +def test_chat_consumer_requirements_shape(): + assert CHAT_CONSUMER_REQUIREMENTS.required == { + CapabilityName.SYSTEM_PROMPT, + CapabilityName.MULTI_TURN, + } -def test_validate_passes_when_target_supports_all_natively(): - caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) - ) - reqs.validate(configuration=config) +def test_requirements_are_frozen(): + reqs = TargetRequirements(required=frozenset({CapabilityName.MULTI_TURN})) + with pytest.raises(Exception): + reqs.required = frozenset() # type: ignore[misc] -def test_validate_passes_when_policy_is_adapt(adapt_all_policy): - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) +def test_validate_passes_on_native_support(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=True, + ), + ), ) - reqs.validate(configuration=config) - - -def test_validate_passes_with_empty_requirements(): - caps = TargetCapabilities(supports_multi_turn=True, 
supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements() - reqs.validate(configuration=config) - -# --------------------------------------------------------------------------- -# validate — failures -# --------------------------------------------------------------------------- - - -def test_validate_raises_when_capability_missing_and_no_policy(): - # EDITABLE_HISTORY has no normalizer and no handling policy — validate raises. - caps = TargetCapabilities(supports_editable_history=False, supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.EDITABLE_HISTORY})) - with pytest.raises(ValueError, match="supports_editable_history"): - reqs.validate(configuration=config) - - -def test_validate_raises_when_capability_missing_and_policy_raise(adapt_all_policy): - # json_output is missing and the policy is RAISE — validate raises. 
- caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.JSON_OUTPUT})) - with pytest.raises(ValueError, match="supports_json_output"): - reqs.validate(configuration=config) - - -def test_validate_collects_all_unsatisfied_capabilities(adapt_all_policy): - """When multiple capabilities are missing, validate reports all violations.""" - caps = TargetCapabilities( - supports_multi_turn=False, - supports_system_prompt=False, - supports_json_output=False, - supports_editable_history=False, + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_passes_when_policy_is_adapt(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + }, + ), + ), ) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - # json_output => RAISE, editable_history => no policy (raises) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.JSON_OUTPUT, CapabilityName.EDITABLE_HISTORY}) - ) - with pytest.raises(ValueError, match="2 required capability") as exc_info: - reqs.validate(configuration=config) - assert "supports_json_output" in str(exc_info.value) - assert "supports_editable_history" in str(exc_info.value) + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_raises_when_capability_neither_native_nor_adapt(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + 
behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + }, + ), + ), + ) -def test_validate_mixed_adapt_and_raise(adapt_all_policy): - """One capability adapts but another raises — validate should raise.""" - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - # multi_turn and system_prompt => ADAPT (OK), json_output => RAISE (fail) - reqs = TargetRequirements( - required_capabilities=frozenset( - {CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT, CapabilityName.JSON_OUTPUT} - ) + with pytest.raises(ValueError, match=CapabilityName.SYSTEM_PROMPT.value): + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_empty_required_always_passes(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities(), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + }, + ), + ), ) - with pytest.raises(ValueError, match="supports_json_output"): - reqs.validate(configuration=config) + + TargetRequirements().validate(target=target)