@@ -9,6 +9,7 @@
from typing import Any, Dict, List, Optional

from pyrit.identifiers import ScorerIdentifier
from pyrit.memory import CentralMemory
from pyrit.models import Score, UnvalidatedScore, MessagePiece, Message
from pyrit.score import ScorerPromptValidator
from pyrit.score.true_false.true_false_scorer import TrueFalseScorer
@@ -231,6 +232,9 @@ async def _score_piece_async(
f"Score will be treated as undetermined."
)

# Extract token usage from eval_result
token_usage = self._extract_token_usage(eval_result, metric_name)

if raw_score is None:
self.logger.warning(f"No matching result found for metric '{metric_name}' in evaluation response.")
raw_score = 0
@@ -255,18 +259,25 @@ async def _score_piece_async(
score_type="true_false",
score_category=[self.risk_category.value],
score_rationale=reason,
score_metadata=self._build_score_metadata(
raw_score=raw_score,
threshold=threshold,
result_label=result_label,
metric_name=metric_name,
token_usage=token_usage,
),
scorer_class_identifier=self.get_identifier(),
message_piece_id=request_response.id,
objective=task or "",
)

# Save score to PyRIT memory so it's available via attack_result.last_score
try:
memory = CentralMemory.get_memory_instance()
memory.add_scores_to_memory(scores=[score])
except Exception as mem_err:
self.logger.debug(f"Could not save score to memory: {mem_err}")

return [score]

except Exception as e:
@@ -349,6 +360,99 @@ def _get_context_for_piece(self, piece: MessagePiece) -> str:

return ""

def _extract_token_usage(self, eval_result: Any, metric_name: str) -> Dict[str, Any]:
"""Extract token usage metrics from the RAI service evaluation result.

Checks sample.usage first, then falls back to result-level properties.

:param eval_result: The evaluation result from RAI service
:type eval_result: Any
:param metric_name: The metric name used for the evaluation
:type metric_name: str
:return: Dictionary with token usage metrics (may be empty)
:rtype: Dict[str, Any]
"""
token_usage: Dict[str, Any] = {}

# Try sample.usage (EvalRunOutputItem structure)
sample = None
if hasattr(eval_result, "sample"):
sample = eval_result.sample
elif isinstance(eval_result, dict):
sample = eval_result.get("sample")

if sample:
usage = sample.get("usage") if isinstance(sample, dict) else getattr(sample, "usage", None)
if usage:
usage_dict = usage if isinstance(usage, dict) else getattr(usage, "__dict__", {})
for key in ("prompt_tokens", "completion_tokens", "total_tokens", "cached_tokens"):
if key in usage_dict and usage_dict[key] is not None:
token_usage[key] = usage_dict[key]

# Fallback: check result-level properties.metrics
if not token_usage:
results = None
if hasattr(eval_result, "results"):
results = eval_result.results
elif isinstance(eval_result, dict):
results = eval_result.get("results")

if results:
# Build a set of metric aliases to match against, to support
# both canonical and legacy metric names.
metric_aliases = {metric_name}
legacy_name = _SYNC_TO_LEGACY_METRIC_NAMES.get(metric_name)
if legacy_name:
metric_aliases.add(legacy_name)
sync_name = _LEGACY_TO_SYNC_METRIC_NAMES.get(metric_name)
if sync_name:
metric_aliases.add(sync_name)

for result_item in results or []:
result_dict = result_item if isinstance(result_item, dict) else getattr(result_item, "__dict__", {})
result_name = result_dict.get("name") or result_dict.get("metric")
if result_name in metric_aliases:
props = result_dict.get("properties", {})
if isinstance(props, dict):
metrics = props.get("metrics", {})
if isinstance(metrics, dict):
for key in ("prompt_tokens", "completion_tokens", "total_tokens", "cached_tokens"):
if key in metrics and metrics[key] is not None:
token_usage[key] = metrics[key]
break

return token_usage

def _build_score_metadata(
self,
*,
raw_score: Any,
threshold: Any,
result_label: str,
metric_name: str,
token_usage: Dict[str, Any],
) -> Dict[str, Any]:
"""Build the score_metadata dictionary for a Score object.

:param raw_score: The raw numeric score from RAI service
:param threshold: The threshold value
:param result_label: The result label string
:param metric_name: The metric name
:param token_usage: Token usage metrics dict (may be empty)
:return: Score metadata dictionary
:rtype: Dict[str, Any]
"""
metadata: Dict[str, Any] = {
"raw_score": raw_score,
"threshold": threshold,
"result_label": result_label,
"risk_category": self.risk_category.value,
"metric_name": metric_name,
}
if token_usage:
metadata["token_usage"] = token_usage
return metadata

def validate(
self,
request_response: MessagePiece,
@@ -213,6 +213,7 @@ def to_red_team_result(
# Determine attack success based on evaluation results if available
attack_success = None
risk_assessment = {}
scorer_token_usage = None

eval_row = None

@@ -291,12 +292,22 @@
score_data = conv_data.get("score", {})
if score_data and isinstance(score_data, dict):
score_metadata = score_data.get("metadata", {})
# Handle string metadata (e.g. from PyRIT serialization)
if isinstance(score_metadata, str):
try:
score_metadata = json.loads(score_metadata)
except (json.JSONDecodeError, TypeError):
score_metadata = {}
if isinstance(score_metadata, dict):
raw_score = score_metadata.get("raw_score")
if raw_score is not None:
risk_assessment[risk_category] = {
"severity_label": get_harm_severity_level(raw_score),
"reason": score_data.get("rationale", ""),
}

# Extract scorer token usage for downstream propagation
scorer_token_usage = score_metadata.get("token_usage")

# Add to tracking arrays for statistical analysis
converters.append(strategy_name)
@@ -350,6 +361,10 @@
if "risk_sub_type" in conv_data:
conversation["risk_sub_type"] = conv_data["risk_sub_type"]

# Add scorer token usage if extracted from score metadata
if scorer_token_usage and isinstance(scorer_token_usage, dict):
conversation["scorer_token_usage"] = scorer_token_usage

# Add evaluation error if present in eval_row
if eval_row and "error" in eval_row:
conversation["error"] = eval_row["error"]
@@ -901,6 +916,12 @@ def _build_output_result(
reason = reasoning
break

# Fallback: use scorer token usage from conversation when eval_row doesn't provide metrics
if "metrics" not in properties:
scorer_token_usage = conversation.get("scorer_token_usage")
if scorer_token_usage and isinstance(scorer_token_usage, dict):
properties["metrics"] = scorer_token_usage

if (
passed is None
and score is None